mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 05:23:11 -05:00
Compare commits
10 Commits
alexcheema
...
v1.0.64
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55b67e2be2 | ||
|
|
df240f834d | ||
|
|
cd125b3b8c | ||
|
|
30cfad9b68 | ||
|
|
b783a21399 | ||
|
|
43f12f5d08 | ||
|
|
8027d7933f | ||
|
|
ac6efa747b | ||
|
|
2e3c33db6d | ||
|
|
fc8e6ad06b |
@@ -14,6 +14,7 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
@@ -24,6 +25,8 @@ struct ContentView: View {
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
@State private var pendingHFToken: String = ""
|
||||
@State private var pendingEnableImageModels = false
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -303,6 +306,49 @@ struct ContentView: View {
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
}
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("HuggingFace Token")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
SecureField("optional", text: $pendingHFToken)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingHFToken = controller.hfToken
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.hfToken = pendingHFToken
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingHFToken == controller.hfToken)
|
||||
}
|
||||
}
|
||||
Divider()
|
||||
HStack {
|
||||
Toggle(
|
||||
"Enable Image Models (experimental)", isOn: $pendingEnableImageModels
|
||||
)
|
||||
.toggleStyle(.switch)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingEnableImageModels = controller.enableImageModels
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Button("Save & Restart") {
|
||||
controller.enableImageModels = pendingEnableImageModels
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingEnableImageModels == controller.enableImageModels)
|
||||
}
|
||||
HoverButton(title: "Check for Updates", small: true) {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
@@ -423,6 +469,44 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
/// Shows TB bridge status for all nodes from exo cluster state
|
||||
private var clusterThunderboltBridgeView: some View {
|
||||
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
|
||||
let localNodeId = stateService.localNodeId
|
||||
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
|
||||
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
if bridgeStatuses.isEmpty {
|
||||
Text("Cluster TB Bridge: No data")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
Text("Cluster TB Bridge Status:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
|
||||
if let status = bridgeStatuses[nodeId] {
|
||||
let nodeName =
|
||||
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
|
||||
let isLocal = nodeId == localNodeId
|
||||
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
|
||||
let statusText =
|
||||
!status.exists
|
||||
? "N/A"
|
||||
: (status.enabled ? "Enabled" : "Disabled")
|
||||
let color: Color =
|
||||
!status.exists
|
||||
? .secondary
|
||||
: (status.enabled ? .red : .green)
|
||||
Text("\(prefix) \(statusText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(color)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var interfaceIpList: some View {
|
||||
let statuses = networkStatusService.status.interfaceStatuses
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
@@ -465,6 +549,7 @@ struct ContentView: View {
|
||||
Text(thunderboltStatusText)
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
clusterThunderboltBridgeView
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
|
||||
@@ -21,6 +21,7 @@ struct EXOApp: App {
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
|
||||
@@ -41,10 +42,13 @@ struct EXOApp: App {
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
|
||||
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
// Remove old LaunchDaemon components if they exist (from previous versions)
|
||||
cleanupLegacyNetworkSetup()
|
||||
// Check local network access periodically (warning disappears when user grants permission)
|
||||
localNetwork.startPeriodicChecking(interval: 10)
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -58,6 +62,7 @@ struct EXOApp: App {
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
.environmentObject(thunderboltBridgeService)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
}
|
||||
@@ -130,6 +135,37 @@ struct EXOApp: App {
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
private func cleanupLegacyNetworkSetup() {
|
||||
guard NetworkSetupHelper.hasInstalledComponents() else { return }
|
||||
// Dispatch async to ensure app is ready before showing alert
|
||||
DispatchQueue.main.async {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to configure local network discovery on your device. This requires granting permission once."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Continue")
|
||||
alert.addButton(withTitle: "Later")
|
||||
|
||||
let response = alert.runModal()
|
||||
guard response == .alertFirstButtonReturn else {
|
||||
Logger().info("User deferred legacy network setup cleanup")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try NetworkSetupHelper.uninstall()
|
||||
Logger().info("Cleaned up legacy network setup components")
|
||||
} catch {
|
||||
// Non-fatal: user may have cancelled admin prompt or cleanup may have
|
||||
// partially succeeded. The app will continue normally.
|
||||
Logger().warning(
|
||||
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
|
||||
@@ -3,6 +3,8 @@ import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
private let hfTokenKey = "EXOHFToken"
|
||||
private let enableImageModelsKey = "EXOEnableImageModels"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
@@ -37,6 +39,22 @@ final class ExoProcessController: ObservableObject {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
@Published var hfToken: String = {
|
||||
return UserDefaults.standard.string(forKey: hfTokenKey) ?? ""
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(hfToken, forKey: hfTokenKey)
|
||||
}
|
||||
}
|
||||
@Published var enableImageModels: Bool = {
|
||||
return UserDefaults.standard.bool(forKey: enableImageModelsKey)
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -191,6 +209,12 @@ final class ExoProcessController: ObservableObject {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
if !hfToken.isEmpty {
|
||||
environment["HF_TOKEN"] = hfToken
|
||||
}
|
||||
if enableImageModels {
|
||||
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
|
||||
}
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
|
||||
@@ -5,17 +5,43 @@ import Foundation
|
||||
struct ClusterState: Decodable {
|
||||
let instances: [String: ClusterInstance]
|
||||
let runners: [String: RunnerStatusSummary]
|
||||
let nodeProfiles: [String: NodeProfile]
|
||||
let tasks: [String: ClusterTask]
|
||||
let topology: Topology?
|
||||
let downloads: [String: [NodeDownloadStatus]]
|
||||
let thunderboltBridgeCycles: [[String]]
|
||||
|
||||
// Granular node state (split from the old nodeProfiles)
|
||||
let nodeIdentities: [String: NodeIdentity]
|
||||
let nodeMemory: [String: MemoryInfo]
|
||||
let nodeSystem: [String: SystemInfo]
|
||||
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
|
||||
|
||||
/// Computed property for backwards compatibility - merges granular state into NodeProfile
|
||||
var nodeProfiles: [String: NodeProfile] {
|
||||
var profiles: [String: NodeProfile] = [:]
|
||||
let allNodeIds = Set(nodeIdentities.keys)
|
||||
.union(nodeMemory.keys)
|
||||
.union(nodeSystem.keys)
|
||||
for nodeId in allNodeIds {
|
||||
let identity = nodeIdentities[nodeId]
|
||||
let memory = nodeMemory[nodeId]
|
||||
let system = nodeSystem[nodeId]
|
||||
profiles[nodeId] = NodeProfile(
|
||||
modelId: identity?.modelId,
|
||||
chipId: identity?.chipId,
|
||||
friendlyName: identity?.friendlyName,
|
||||
memory: memory,
|
||||
system: system
|
||||
)
|
||||
}
|
||||
return profiles
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
self.thunderboltBridgeCycles =
|
||||
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
|
||||
|
||||
// Granular node state
|
||||
self.nodeIdentities =
|
||||
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
|
||||
?? [:]
|
||||
self.nodeMemory =
|
||||
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
|
||||
self.nodeSystem =
|
||||
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
|
||||
self.nodeThunderboltBridge =
|
||||
try container.decodeIfPresent(
|
||||
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
|
||||
) ?? [:]
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case instances
|
||||
case runners
|
||||
case nodeProfiles
|
||||
case topology
|
||||
case tasks
|
||||
case downloads
|
||||
case thunderboltBridgeCycles
|
||||
case nodeIdentities
|
||||
case nodeMemory
|
||||
case nodeSystem
|
||||
case nodeThunderboltBridge
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
|
||||
let system: SystemInfo?
|
||||
}
|
||||
|
||||
struct NodeIdentity: Decodable {
|
||||
let modelId: String?
|
||||
let chipId: String?
|
||||
let friendlyName: String?
|
||||
}
|
||||
|
||||
struct ThunderboltBridgeStatus: Decodable {
|
||||
let enabled: Bool
|
||||
let exists: Bool
|
||||
let serviceName: String?
|
||||
}
|
||||
|
||||
struct MemoryInfo: Decodable {
|
||||
let ramTotal: MemoryValue?
|
||||
let ramAvailable: MemoryValue?
|
||||
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
|
||||
}
|
||||
|
||||
struct Topology: Decodable {
|
||||
let nodes: [TopologyNode]
|
||||
let connections: [TopologyConnection]?
|
||||
/// Node IDs in the topology
|
||||
let nodes: [String]
|
||||
/// Flattened list of connections (source -> sink pairs)
|
||||
let connections: [TopologyConnection]
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
|
||||
|
||||
// Connections come as nested map: { source: { sink: [edges] } }
|
||||
// We flatten to array of (source, sink) pairs
|
||||
var flatConnections: [TopologyConnection] = []
|
||||
if let nested = try container.decodeIfPresent(
|
||||
[String: [String: [AnyCodable]]].self, forKey: .connections
|
||||
) {
|
||||
for (source, sinks) in nested {
|
||||
for sink in sinks.keys {
|
||||
flatConnections.append(
|
||||
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
|
||||
}
|
||||
}
|
||||
}
|
||||
self.connections = flatConnections
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case nodes
|
||||
case connections
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyNode: Decodable {
|
||||
let nodeId: String
|
||||
let nodeProfile: NodeProfile
|
||||
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
|
||||
private struct AnyCodable: Decodable {
|
||||
init(from decoder: Decoder) throws {
|
||||
// Just consume the value without storing it
|
||||
_ = try? decoder.singleValueContainer().decode(Bool.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Int.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Double.self)
|
||||
_ = try? decoder.singleValueContainer().decode(String.self)
|
||||
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
|
||||
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyConnection: Decodable {
|
||||
struct TopologyConnection {
|
||||
let localNodeId: String
|
||||
let sendBackNodeId: String
|
||||
}
|
||||
|
||||
@@ -55,12 +55,16 @@ struct BugReportService {
|
||||
let stateData = try await stateResult
|
||||
let eventsData = try await eventsResult
|
||||
|
||||
// Extract cluster TB bridge status from exo state
|
||||
let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)
|
||||
|
||||
let reportJSON = makeReportJson(
|
||||
timestamp: timestamp,
|
||||
hostName: hostName,
|
||||
ifconfig: ifconfigText,
|
||||
debugInfo: debugInfo,
|
||||
isManual: isManual
|
||||
isManual: isManual,
|
||||
clusterTbBridgeStatus: clusterTbBridgeStatus
|
||||
)
|
||||
|
||||
let uploads: [(path: String, data: Data?)] = [
|
||||
@@ -178,18 +182,19 @@ struct BugReportService {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
let result = runCommand([
|
||||
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
|
||||
])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return false
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return nil
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return true
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
// Return true if disabled, false if enabled
|
||||
return !isEnabled
|
||||
}
|
||||
|
||||
private func readInterfaces() -> [DebugInfo.InterfaceStatus] {
|
||||
@@ -268,11 +273,12 @@ struct BugReportService {
|
||||
hostName: String,
|
||||
ifconfig: String,
|
||||
debugInfo: DebugInfo,
|
||||
isManual: Bool
|
||||
isManual: Bool,
|
||||
clusterTbBridgeStatus: [[String: Any]]?
|
||||
) -> Data? {
|
||||
let system = readSystemMetadata()
|
||||
let exo = readExoMetadata()
|
||||
let payload: [String: Any] = [
|
||||
var payload: [String: Any] = [
|
||||
"timestamp": timestamp,
|
||||
"host": hostName,
|
||||
"ifconfig": ifconfig,
|
||||
@@ -282,9 +288,38 @@ struct BugReportService {
|
||||
"exo_commit": exo.commit as Any,
|
||||
"report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
if let tbStatus = clusterTbBridgeStatus {
|
||||
payload["cluster_thunderbolt_bridge"] = tbStatus
|
||||
}
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
|
||||
/// Extracts cluster-wide Thunderbolt Bridge status from exo state JSON
|
||||
private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {
|
||||
guard let data = stateData,
|
||||
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
|
||||
let nodeThunderboltBridge = json["node_thunderbolt_bridge"] as? [String: [String: Any]]
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result: [[String: Any]] = []
|
||||
for (nodeId, status) in nodeThunderboltBridge {
|
||||
var entry: [String: Any] = ["node_id": nodeId]
|
||||
if let enabled = status["enabled"] as? Bool {
|
||||
entry["enabled"] = enabled
|
||||
}
|
||||
if let exists = status["exists"] as? Bool {
|
||||
entry["exists"] = exists
|
||||
}
|
||||
if let serviceName = status["service_name"] as? String {
|
||||
entry["service_name"] = serviceName
|
||||
}
|
||||
result.append(entry)
|
||||
}
|
||||
return result.isEmpty ? nil : result
|
||||
}
|
||||
|
||||
private func readSystemMetadata() -> [String: Any] {
|
||||
let hostname = safeRunCommand(["/bin/hostname"])
|
||||
let computerName = safeRunCommand(["/usr/sbin/scutil", "--get", "ComputerName"])
|
||||
|
||||
@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
private var periodicTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
/// Checks if local network access is working (one-time check).
|
||||
func check() {
|
||||
performCheck()
|
||||
}
|
||||
|
||||
/// Starts periodic checking of local network access.
|
||||
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
|
||||
func startPeriodicChecking(interval: TimeInterval = 10) {
|
||||
stopPeriodicChecking()
|
||||
// Do an immediate check first
|
||||
performCheck()
|
||||
// Then schedule periodic checks
|
||||
periodicTask = Task { [weak self] in
|
||||
while !Task.isCancelled {
|
||||
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
|
||||
guard !Task.isCancelled else { break }
|
||||
self?.performCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops periodic checking.
|
||||
func stopPeriodicChecking() {
|
||||
periodicTask?.cancel()
|
||||
periodicTask = nil
|
||||
}
|
||||
|
||||
private func performCheck() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
// Only show "checking" status on first check to avoid UI flicker
|
||||
if status == .unknown {
|
||||
status = .checking
|
||||
}
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
// Only log on state changes or first check to reduce noise
|
||||
if isFirstCheck || result != self.status {
|
||||
Self.logger.info("Local network check: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
stopPeriodicChecking()
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
|
||||
@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge.sh"
|
||||
// Legacy script path from older versions
|
||||
private static let legacyScriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
}
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
return scriptExists || legacyScriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
# Remove the script (current and legacy paths) and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rm -f "$LEGACY_SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
@@ -98,99 +63,42 @@ enum NetworkSetupHelper {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
} || true
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
# We find it dynamically by looking for bridges containing Thunderbolt interfaces
|
||||
find_and_enable_thunderbolt_bridge() {
|
||||
# Get Thunderbolt interface devices from hardware ports
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
if echo "$members" | grep -qx "$tb_dev"; then
|
||||
# Find the service name for this bridge device
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
|
||||
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let interval = plist["StartInterval"] as? Int,
|
||||
interval == requiredStartInterval
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private static func installLaunchDaemon() async throws {
|
||||
let installerScript = makeInstallerScript()
|
||||
try runShellAsAdmin(installerScript)
|
||||
}
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript =
|
||||
script
|
||||
|
||||
@@ -153,22 +153,18 @@ private struct NetworkStatusFetcher {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
let lower = result.output.lowercased() + result.error.lowercased()
|
||||
if lower.contains("not a recognized network service") {
|
||||
return .deleted
|
||||
}
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return .deleted
|
||||
}
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return .enabled
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return .disabled
|
||||
}
|
||||
return nil
|
||||
|
||||
return isEnabled ? .enabled : .disabled
|
||||
}
|
||||
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
|
||||
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
@@ -0,0 +1,194 @@
|
||||
import Foundation
|
||||
import os.log
|
||||
|
||||
/// Utility for dynamically detecting Thunderbolt Bridge network services.
|
||||
/// This mirrors the Python logic in info_gatherer.py - we never assume the service
|
||||
/// is named "Thunderbolt Bridge", instead we find bridges containing Thunderbolt interfaces.
|
||||
enum ThunderboltBridgeDetector {
|
||||
private static let logger = Logger(
|
||||
subsystem: "io.exo.EXO", category: "ThunderboltBridgeDetector")
|
||||
|
||||
struct CommandResult {
|
||||
let exitCode: Int32
|
||||
let output: String
|
||||
let error: String
|
||||
}
|
||||
|
||||
/// Find the network service name of a bridge containing Thunderbolt interfaces.
|
||||
/// Returns nil if no such bridge exists.
|
||||
static func findThunderboltBridgeServiceName() -> String? {
|
||||
// 1. Get all Thunderbolt interface devices (e.g., en2, en3)
|
||||
guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {
|
||||
logger.debug("No Thunderbolt devices found")
|
||||
return nil
|
||||
}
|
||||
logger.debug("Found Thunderbolt devices: \(thunderboltDevices.joined(separator: ", "))")
|
||||
|
||||
// 2. Get bridge services from network service order
|
||||
guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {
|
||||
logger.debug("No bridge services found")
|
||||
return nil
|
||||
}
|
||||
logger.debug("Found bridge services: \(bridgeServices.keys.joined(separator: ", "))")
|
||||
|
||||
// 3. Find a bridge that contains Thunderbolt interfaces
|
||||
for (bridgeDevice, serviceName) in bridgeServices {
|
||||
let members = getBridgeMembers(bridgeDevice: bridgeDevice)
|
||||
logger.debug(
|
||||
"Bridge \(bridgeDevice) (\(serviceName)) has members: \(members.joined(separator: ", "))"
|
||||
)
|
||||
|
||||
// Check if any Thunderbolt device is a member of this bridge
|
||||
if !members.isDisjoint(with: thunderboltDevices) {
|
||||
logger.info(
|
||||
"Found Thunderbolt Bridge service: '\(serviceName)' (device: \(bridgeDevice))")
|
||||
return serviceName
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("No bridge found containing Thunderbolt interfaces")
|
||||
return nil
|
||||
}
|
||||
|
||||
/// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
|
||||
private static func getThunderboltDevices() -> Set<String>? {
|
||||
let result = runCommand(["networksetup", "-listallhardwareports"])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("networksetup -listallhardwareports failed: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
var thunderboltDevices: Set<String> = []
|
||||
var currentPort: String?
|
||||
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("Hardware Port:") {
|
||||
currentPort = String(trimmed.dropFirst("Hardware Port:".count)).trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("Device:"), let port = currentPort {
|
||||
let device = String(trimmed.dropFirst("Device:".count)).trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
if port.lowercased().contains("thunderbolt") {
|
||||
thunderboltDevices.insert(device)
|
||||
}
|
||||
currentPort = nil
|
||||
}
|
||||
}
|
||||
|
||||
return thunderboltDevices
|
||||
}
|
||||
|
||||
/// Get mapping of bridge device -> service name from network service order.
|
||||
private static func getBridgeServices() -> [String: String]? {
|
||||
let result = runCommand(["networksetup", "-listnetworkserviceorder"])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("networksetup -listnetworkserviceorder failed: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse service order to find bridge devices and their service names
|
||||
// Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
|
||||
var bridgeServices: [String: String] = [:]
|
||||
var currentService: String?
|
||||
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
|
||||
// Match "(N) Service Name" or "(*) Service Name" (disabled)
|
||||
// but NOT "(Hardware Port: ...)" lines
|
||||
if trimmed.hasPrefix("("), trimmed.contains(")"),
|
||||
!trimmed.hasPrefix("(Hardware Port:")
|
||||
{
|
||||
if let parenEnd = trimmed.firstIndex(of: ")") {
|
||||
let afterParen = trimmed.index(after: parenEnd)
|
||||
if afterParen < trimmed.endIndex {
|
||||
currentService =
|
||||
String(trimmed[afterParen...])
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Match "(Hardware Port: ..., Device: bridgeX)"
|
||||
else if let service = currentService, trimmed.contains("Device: bridge") {
|
||||
// Extract device name from "..., Device: bridge0)"
|
||||
if let deviceRange = trimmed.range(of: "Device: ") {
|
||||
let afterDevice = trimmed[deviceRange.upperBound...]
|
||||
if let parenIndex = afterDevice.firstIndex(of: ")") {
|
||||
let device = String(afterDevice[..<parenIndex])
|
||||
bridgeServices[device] = service
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bridgeServices
|
||||
}
|
||||
|
||||
/// Get member interfaces of a bridge device via ifconfig.
|
||||
private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {
|
||||
let result = runCommand(["ifconfig", bridgeDevice])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.debug("ifconfig \(bridgeDevice) failed")
|
||||
return []
|
||||
}
|
||||
|
||||
var members: Set<String> = []
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("member:") {
|
||||
let parts = trimmed.split(separator: " ")
|
||||
if parts.count > 1 {
|
||||
members.insert(String(parts[1]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return members
|
||||
}
|
||||
|
||||
/// Check if a network service is enabled.
|
||||
static func isServiceEnabled(serviceName: String) -> Bool? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", serviceName])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("Failed to check if '\(serviceName)' is enabled: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
if output.contains("enabled") {
|
||||
return true
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private static func runCommand(_ arguments: [String]) -> CommandResult {
|
||||
let process = Process()
|
||||
process.launchPath = "/usr/bin/env"
|
||||
process.arguments = arguments
|
||||
|
||||
let stdout = Pipe()
|
||||
let stderr = Pipe()
|
||||
process.standardOutput = stdout
|
||||
process.standardError = stderr
|
||||
|
||||
do {
|
||||
try process.run()
|
||||
} catch {
|
||||
return CommandResult(exitCode: -1, output: "", error: error.localizedDescription)
|
||||
}
|
||||
process.waitUntilExit()
|
||||
|
||||
let outputData = stdout.fileHandleForReading.readDataToEndOfFile()
|
||||
let errorData = stderr.fileHandleForReading.readDataToEndOfFile()
|
||||
|
||||
return CommandResult(
|
||||
exitCode: process.terminationStatus,
|
||||
output: String(decoding: outputData, as: UTF8.self),
|
||||
error: String(decoding: errorData, as: UTF8.self)
|
||||
)
|
||||
}
|
||||
}
|
||||
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
@@ -0,0 +1,258 @@
|
||||
import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
import Security
|
||||
import SystemConfiguration
|
||||
import os.log
|
||||
|
||||
@MainActor
|
||||
final class ThunderboltBridgeService: ObservableObject {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
|
||||
|
||||
@Published private(set) var detectedCycle: [String]?
|
||||
@Published private(set) var hasPromptedForCurrentCycle = false
|
||||
@Published private(set) var lastError: String?
|
||||
|
||||
private weak var clusterStateService: ClusterStateService?
|
||||
private var cancellables = Set<AnyCancellable>()
|
||||
private var previousCycleSignature: String?
|
||||
|
||||
init(clusterStateService: ClusterStateService) {
|
||||
self.clusterStateService = clusterStateService
|
||||
setupObserver()
|
||||
}
|
||||
|
||||
private func setupObserver() {
|
||||
guard let service = clusterStateService else { return }
|
||||
|
||||
service.$latestSnapshot
|
||||
.compactMap { $0 }
|
||||
.sink { [weak self] snapshot in
|
||||
self?.checkForCycles(snapshot: snapshot)
|
||||
}
|
||||
.store(in: &cancellables)
|
||||
}
|
||||
|
||||
private func checkForCycles(snapshot: ClusterState) {
|
||||
let cycles = snapshot.thunderboltBridgeCycles
|
||||
|
||||
// Only consider cycles with more than 2 nodes
|
||||
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
|
||||
// No problematic cycles detected, reset state
|
||||
if detectedCycle != nil {
|
||||
detectedCycle = nil
|
||||
previousCycleSignature = nil
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Create a signature for this cycle to detect if it changed
|
||||
let cycleSignature = firstCycle.sorted().joined(separator: ",")
|
||||
|
||||
// If this is a new/different cycle, reset the prompt state
|
||||
if cycleSignature != previousCycleSignature {
|
||||
previousCycleSignature = cycleSignature
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
|
||||
detectedCycle = firstCycle
|
||||
|
||||
// Only prompt once per cycle
|
||||
if !hasPromptedForCurrentCycle {
|
||||
showDisableBridgePrompt(nodeIds: firstCycle)
|
||||
}
|
||||
}
|
||||
|
||||
private func showDisableBridgePrompt(nodeIds: [String]) {
|
||||
hasPromptedForCurrentCycle = true
|
||||
|
||||
// Get friendly names for the nodes if available
|
||||
let nodeNames = nodeIds.map { nodeId -> String in
|
||||
if let snapshot = clusterStateService?.latestSnapshot,
|
||||
let profile = snapshot.nodeProfiles[nodeId],
|
||||
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
|
||||
{
|
||||
return friendlyName
|
||||
}
|
||||
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
|
||||
}
|
||||
let machineNames = nodeNames.joined(separator: ", ")
|
||||
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Thunderbolt Bridge Loop Detected"
|
||||
alert.informativeText = """
|
||||
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
|
||||
|
||||
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Disable Bridge")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
|
||||
let response = alert.runModal()
|
||||
|
||||
if response == .alertFirstButtonReturn {
|
||||
Task {
|
||||
await disableThunderboltBridge()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func disableThunderboltBridge() async {
|
||||
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
|
||||
lastError = nil
|
||||
|
||||
do {
|
||||
try await disableThunderboltBridgeWithSCPreferences()
|
||||
Self.logger.info("Successfully disabled Thunderbolt Bridge")
|
||||
} catch {
|
||||
Self.logger.error(
|
||||
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
lastError = error.localizedDescription
|
||||
showErrorAlert(message: error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
private func disableThunderboltBridgeWithSCPreferences() async throws {
|
||||
// 1. Create authorization reference
|
||||
var authRef: AuthorizationRef?
|
||||
var status = AuthorizationCreate(nil, nil, [], &authRef)
|
||||
guard status == errAuthorizationSuccess, let authRef = authRef else {
|
||||
throw ThunderboltBridgeError.authorizationFailed
|
||||
}
|
||||
|
||||
defer { AuthorizationFree(authRef, [.destroyRights]) }
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
}
|
||||
throw ThunderboltBridgeError.authorizationDenied
|
||||
}
|
||||
|
||||
// 3. Create SCPreferences with authorization
|
||||
guard
|
||||
let prefs = SCPreferencesCreateWithAuthorization(
|
||||
kCFAllocatorDefault,
|
||||
"EXO" as CFString,
|
||||
nil,
|
||||
authRef
|
||||
)
|
||||
else {
|
||||
throw ThunderboltBridgeError.preferencesCreationFailed
|
||||
}
|
||||
|
||||
// 4. Lock, modify, commit
|
||||
guard SCPreferencesLock(prefs, true) else {
|
||||
throw ThunderboltBridgeError.lockFailed
|
||||
}
|
||||
|
||||
defer {
|
||||
SCPreferencesUnlock(prefs)
|
||||
}
|
||||
|
||||
// 5. Find the Thunderbolt Bridge service dynamically (don't assume the name)
|
||||
guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()
|
||||
else {
|
||||
throw ThunderboltBridgeError.serviceNotFound
|
||||
}
|
||||
|
||||
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
|
||||
throw ThunderboltBridgeError.servicesNotFound
|
||||
}
|
||||
|
||||
var found = false
|
||||
for service in allServices {
|
||||
if let name = SCNetworkServiceGetName(service) as String?,
|
||||
name == targetServiceName
|
||||
{
|
||||
guard SCNetworkServiceSetEnabled(service, false) else {
|
||||
throw ThunderboltBridgeError.disableFailed
|
||||
}
|
||||
found = true
|
||||
Self.logger.info(
|
||||
"Found and disabled Thunderbolt Bridge service: '\(targetServiceName)'")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
throw ThunderboltBridgeError.serviceNotFound
|
||||
}
|
||||
|
||||
// 6. Commit and apply
|
||||
guard SCPreferencesCommitChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.commitFailed
|
||||
}
|
||||
|
||||
guard SCPreferencesApplyChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.applyFailed
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Failed to Disable Thunderbolt Bridge"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
}
|
||||
|
||||
enum ThunderboltBridgeError: LocalizedError {
|
||||
case authorizationFailed
|
||||
case authorizationCanceled
|
||||
case authorizationDenied
|
||||
case preferencesCreationFailed
|
||||
case lockFailed
|
||||
case servicesNotFound
|
||||
case serviceNotFound
|
||||
case disableFailed
|
||||
case commitFailed
|
||||
case applyFailed
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .authorizationFailed:
|
||||
return "Failed to create authorization"
|
||||
case .authorizationCanceled:
|
||||
return "Authorization was canceled by user"
|
||||
case .authorizationDenied:
|
||||
return "Authorization was denied"
|
||||
case .preferencesCreationFailed:
|
||||
return "Failed to access network preferences"
|
||||
case .lockFailed:
|
||||
return "Failed to lock network preferences for modification"
|
||||
case .servicesNotFound:
|
||||
return "Could not retrieve network services"
|
||||
case .serviceNotFound:
|
||||
return "Thunderbolt Bridge service not found"
|
||||
case .disableFailed:
|
||||
return "Failed to disable Thunderbolt Bridge service"
|
||||
case .commitFailed:
|
||||
return "Failed to save network configuration changes"
|
||||
case .applyFailed:
|
||||
return "Failed to apply network configuration changes"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ struct TopologyViewModel {
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let topologyNodeIds = Set(topology?.nodes ?? [])
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
@@ -95,8 +95,8 @@ extension ClusterState {
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
var orderedNodes: [NodeViewModel] = []
|
||||
if let topologyNodes = topology?.nodes {
|
||||
for topoNode in topologyNodes {
|
||||
if let viewModel = nodesById[topoNode.nodeId] {
|
||||
for nodeId in topologyNodes {
|
||||
if let viewModel = nodesById[nodeId] {
|
||||
orderedNodes.append(viewModel)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ extension ClusterState {
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
topology?.connections.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,6 +865,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -904,6 +905,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,6 +1522,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1529,6 +1532,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,6 +1945,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2648,6 +2653,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2690,6 +2696,7 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2862,6 +2869,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3006,6 +3014,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3027,6 +3036,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -5,22 +5,32 @@
|
||||
topologyData,
|
||||
isTopologyMinimized,
|
||||
debugMode,
|
||||
nodeThunderboltBridge,
|
||||
type NodeInfo,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
highlightedNodes?: Set<string>;
|
||||
filteredNodes?: Set<string>;
|
||||
onNodeClick?: (nodeId: string) => void;
|
||||
}
|
||||
|
||||
let { class: className = "", highlightedNodes = new Set() }: Props = $props();
|
||||
let {
|
||||
class: className = "",
|
||||
highlightedNodes = new Set(),
|
||||
filteredNodes = new Set(),
|
||||
onNodeClick,
|
||||
}: Props = $props();
|
||||
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
let hoveredNodeId = $state<string | null>(null);
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -522,10 +532,72 @@
|
||||
}
|
||||
}
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check node states for styling
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
const isInFilter =
|
||||
filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);
|
||||
const isFilteredOut =
|
||||
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
|
||||
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
|
||||
|
||||
// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
|
||||
const wireColor = isInFilter
|
||||
? "rgba(255,215,0,1)" // Bright yellow for filter selection
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.7)" // Subtle yellow for hover
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.9)" // Yellow for instance highlight
|
||||
: isFilteredOut
|
||||
? "rgba(140,140,140,0.6)" // Grey for filtered out
|
||||
: "rgba(179,179,179,0.8)"; // Default
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isInFilter
|
||||
? "rgba(255,215,0,0.25)"
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.12)"
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isInFilter
|
||||
? 3
|
||||
: isHovered
|
||||
? 2
|
||||
: isHighlighted
|
||||
? 2.5
|
||||
: 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
const nodeG = nodesGroup
|
||||
.append("g")
|
||||
.attr("class", "graph-node")
|
||||
.style("cursor", "pointer");
|
||||
.style("cursor", onNodeClick ? "pointer" : "default")
|
||||
.style("opacity", isFilteredOut ? 0.5 : 1);
|
||||
|
||||
// Add click and hover handlers - hover just updates state, styling is applied during render
|
||||
nodeG
|
||||
.on("click", (event: MouseEvent) => {
|
||||
if (onNodeClick) {
|
||||
event.stopPropagation();
|
||||
onNodeClick(nodeInfo.id);
|
||||
}
|
||||
})
|
||||
.on("mouseenter", () => {
|
||||
if (onNodeClick) {
|
||||
hoveredNodeId = nodeInfo.id;
|
||||
}
|
||||
})
|
||||
.on("mouseleave", () => {
|
||||
if (hoveredNodeId === nodeInfo.id) {
|
||||
hoveredNodeId = null;
|
||||
}
|
||||
});
|
||||
|
||||
// Add tooltip
|
||||
nodeG
|
||||
@@ -534,27 +606,6 @@
|
||||
`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,
|
||||
);
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check if this node should be highlighted (from hovered instance)
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
|
||||
// Holographic wireframe colors - yellow border when highlighted
|
||||
const wireColor = isHighlighted
|
||||
? "rgba(255,215,0,0.9)"
|
||||
: "rgba(179,179,179,0.8)";
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isHighlighted ? 2.5 : 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
if (modelLower === "mac studio") {
|
||||
// Mac Studio - classic cube with memory fill
|
||||
iconBaseWidth = nodeRadius * 1.25;
|
||||
@@ -579,6 +630,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -661,6 +713,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -738,6 +791,7 @@
|
||||
// Screen outer frame
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", screenX)
|
||||
.attr("y", y)
|
||||
.attr("width", screenWidth)
|
||||
@@ -846,6 +900,7 @@
|
||||
// Main shape
|
||||
nodeG
|
||||
.append("polygon")
|
||||
.attr("class", "node-outline")
|
||||
.attr("points", hexPoints)
|
||||
.attr("fill", fillColor)
|
||||
.attr("stroke", wireColor)
|
||||
@@ -1064,11 +1119,41 @@
|
||||
.attr("fill", "rgba(179,179,179,0.7)")
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// Debug mode: Show TB bridge status
|
||||
if (debugEnabled) {
|
||||
const tbStatus = tbBridgeData[nodeInfo.id];
|
||||
if (tbStatus) {
|
||||
const tbY =
|
||||
nodeInfo.y +
|
||||
iconBaseHeight / 2 +
|
||||
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const tbFontSize = showFullLabels ? 9 : 7;
|
||||
const tbColor = tbStatus.enabled
|
||||
? "rgba(234,179,8,0.9)"
|
||||
: "rgba(100,100,100,0.7)";
|
||||
const tbText = tbStatus.enabled ? "TB:ON" : "TB:OFF";
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", tbY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", tbColor)
|
||||
.attr("font-size", tbFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(tbText);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
// Track all reactive dependencies that affect rendering
|
||||
const _data = data;
|
||||
const _hoveredNodeId = hoveredNodeId;
|
||||
const _filteredNodes = filteredNodes;
|
||||
const _highlightedNodes = highlightedNodes;
|
||||
if (_data) {
|
||||
renderGraph();
|
||||
}
|
||||
});
|
||||
@@ -1091,12 +1176,8 @@
|
||||
|
||||
<style>
|
||||
:global(.graph-node) {
|
||||
transition:
|
||||
transform 0.2s ease,
|
||||
opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-node:hover) {
|
||||
filter: brightness(1.1);
|
||||
/* Only transition opacity for filtered-out nodes, no transition on hover stroke changes */
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-link) {
|
||||
stroke: var(--exo-light-gray, #b3b3b3);
|
||||
|
||||
@@ -190,6 +190,13 @@ interface RawStateResponse {
|
||||
nodeMemory?: Record<string, RawMemoryUsage>;
|
||||
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
|
||||
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
|
||||
// Thunderbolt bridge status per node
|
||||
nodeThunderboltBridge?: Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -419,7 +426,15 @@ class AppStore {
|
||||
placementPreviews = $state<PlacementPreview[]>([]);
|
||||
selectedPreviewModelId = $state<string | null>(null);
|
||||
isLoadingPreviews = $state(false);
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderboltBridge = $state<
|
||||
Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>
|
||||
>({});
|
||||
|
||||
// UI state
|
||||
isTopologyMinimized = $state(false);
|
||||
@@ -439,6 +454,7 @@ class AppStore {
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private previousNodeIds: Set<string> = new Set();
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -997,6 +1013,8 @@ class AppStore {
|
||||
nodeSystem: data.nodeSystem,
|
||||
nodeNetwork: data.nodeNetwork,
|
||||
});
|
||||
// Handle topology changes for preview filter
|
||||
this.handleTopologyChange();
|
||||
}
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
@@ -1008,6 +1026,10 @@ class AppStore {
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};
|
||||
this.lastUpdate = Date.now();
|
||||
} catch (error) {
|
||||
console.error("Error fetching state:", error);
|
||||
@@ -1023,9 +1045,14 @@ class AppStore {
|
||||
this.selectedPreviewModelId = modelId;
|
||||
|
||||
try {
|
||||
let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;
|
||||
// Add node filter if active
|
||||
if (this.previewNodeFilter.size > 0) {
|
||||
for (const nodeId of this.previewNodeFilter) {
|
||||
url += `&node_ids=${encodeURIComponent(nodeId)}`;
|
||||
}
|
||||
}
|
||||
const response = await fetch(url);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch placement previews: ${response.status}`,
|
||||
@@ -1075,6 +1102,71 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle a node in the preview filter and re-fetch placements
|
||||
*/
|
||||
togglePreviewNodeFilter(nodeId: string) {
|
||||
const next = new Set(this.previewNodeFilter);
|
||||
if (next.has(nodeId)) {
|
||||
next.delete(nodeId);
|
||||
} else {
|
||||
next.add(nodeId);
|
||||
}
|
||||
this.previewNodeFilter = next;
|
||||
// Re-fetch with new filter if we have a selected model
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the preview node filter and re-fetch placements
|
||||
*/
|
||||
clearPreviewNodeFilter() {
|
||||
this.previewNodeFilter = new Set();
|
||||
// Re-fetch with no filter if we have a selected model
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle topology changes - clean up filter and re-fetch if needed
|
||||
*/
|
||||
private handleTopologyChange() {
|
||||
if (!this.topologyData) return;
|
||||
|
||||
const currentNodeIds = new Set(Object.keys(this.topologyData.nodes));
|
||||
|
||||
// Check if nodes have changed
|
||||
const nodesAdded = [...currentNodeIds].some(
|
||||
(id) => !this.previousNodeIds.has(id),
|
||||
);
|
||||
const nodesRemoved = [...this.previousNodeIds].some(
|
||||
(id) => !currentNodeIds.has(id),
|
||||
);
|
||||
|
||||
if (nodesAdded || nodesRemoved) {
|
||||
// Clean up filter - remove any nodes that no longer exist
|
||||
if (this.previewNodeFilter.size > 0) {
|
||||
const validFilterNodes = new Set(
|
||||
[...this.previewNodeFilter].filter((id) => currentNodeIds.has(id)),
|
||||
);
|
||||
if (validFilterNodes.size !== this.previewNodeFilter.size) {
|
||||
this.previewNodeFilter = validFilterNodes;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-fetch previews if we have a selected model (topology changed)
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
// Update tracked node IDs for next comparison
|
||||
this.previousNodeIds = currentNodeIds;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a chat conversation - triggers the topology minimization animation
|
||||
* Creates a new conversation if none is active
|
||||
@@ -2098,6 +2190,10 @@ export const setSelectedChatModel = (modelId: string) =>
|
||||
appStore.setSelectedModel(modelId);
|
||||
export const selectPreviewModel = (modelId: string | null) =>
|
||||
appStore.selectPreviewModel(modelId);
|
||||
export const togglePreviewNodeFilter = (nodeId: string) =>
|
||||
appStore.togglePreviewNodeFilter(nodeId);
|
||||
export const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();
|
||||
export const previewNodeFilter = () => appStore.previewNodeFilter;
|
||||
export const deleteMessage = (messageId: string) =>
|
||||
appStore.deleteMessage(messageId);
|
||||
export const editMessage = (messageId: string, newContent: string) =>
|
||||
@@ -2134,6 +2230,10 @@ export const setChatSidebarVisible = (visible: boolean) =>
|
||||
appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
// Thunderbolt bridge status
|
||||
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
|
||||
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
|
||||
|
||||
// Image generation params
|
||||
export const imageGenerationParams = () => appStore.getImageGenerationParams();
|
||||
export const setImageGenerationParams = (
|
||||
|
||||
@@ -19,6 +19,9 @@
|
||||
selectedPreviewModelId,
|
||||
isLoadingPreviews,
|
||||
selectPreviewModel,
|
||||
togglePreviewNodeFilter,
|
||||
clearPreviewNodeFilter,
|
||||
previewNodeFilter,
|
||||
createConversation,
|
||||
setSelectedChatModel,
|
||||
selectedChatModel,
|
||||
@@ -28,6 +31,8 @@
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview,
|
||||
} from "$lib/stores/app.svelte";
|
||||
@@ -49,6 +54,41 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const nodeFilter = $derived(previewNodeFilter());
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8) + "...";
|
||||
}
|
||||
|
||||
// Helper to get the thunderbolt bridge service name from a cycle
|
||||
function getTbBridgeServiceName(cycle: string[]): string {
|
||||
// Try to find service name from any node in the cycle
|
||||
for (const nodeId of cycle) {
|
||||
const nodeData = tbBridgeData?.[nodeId];
|
||||
if (nodeData?.serviceName) {
|
||||
return nodeData.serviceName;
|
||||
}
|
||||
}
|
||||
return "Thunderbolt Bridge"; // Fallback if no service name found
|
||||
}
|
||||
|
||||
// Copy to clipboard state and function
|
||||
let copiedCommand = $state(false);
|
||||
async function copyToClipboard(text: string) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
copiedCommand = true;
|
||||
setTimeout(() => {
|
||||
copiedCommand = false;
|
||||
}, 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
}
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -90,6 +130,15 @@
|
||||
model.tasks.includes("ImageToImage")
|
||||
);
|
||||
}
|
||||
|
||||
// Helper to check if a model supports image editing
|
||||
function modelSupportsImageEditing(modelId: string): boolean {
|
||||
const model = models.find(
|
||||
(m) => m.id === modelId || m.hugging_face_id === modelId,
|
||||
);
|
||||
if (!model?.tasks) return false;
|
||||
return model.tasks.includes("ImageToImage");
|
||||
}
|
||||
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
|
||||
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
|
||||
@@ -181,6 +230,9 @@
|
||||
// Preview card hover state for highlighting nodes in topology
|
||||
let hoveredPreviewNodes = $state<Set<string>>(new Set());
|
||||
|
||||
// Computed: Check if filter is active (from store)
|
||||
const isFilterActive = $derived(() => nodeFilter.size > 0);
|
||||
|
||||
// Helper to unwrap tagged instance for hover highlighting
|
||||
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object")
|
||||
@@ -732,6 +784,8 @@
|
||||
instanceWrapped: unknown,
|
||||
): {
|
||||
isDownloading: boolean;
|
||||
isFailed: boolean;
|
||||
errorMessage: string | null;
|
||||
progress: DownloadProgress | null;
|
||||
statusText: string;
|
||||
perNode: Array<{
|
||||
@@ -743,6 +797,8 @@
|
||||
if (!downloadsData || Object.keys(downloadsData).length === 0) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "RUNNING",
|
||||
perNode: [],
|
||||
@@ -754,6 +810,8 @@
|
||||
if (!instance || typeof instance !== "object") {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "PREPARING",
|
||||
perNode: [],
|
||||
@@ -809,6 +867,26 @@
|
||||
downloadKind
|
||||
] as Record<string, unknown>;
|
||||
|
||||
// Handle DownloadFailed - return immediately with error info
|
||||
if (downloadKind === "DownloadFailed") {
|
||||
const downloadModelId = extractModelIdFromDownload(downloadPayload);
|
||||
if (
|
||||
instanceModelId &&
|
||||
downloadModelId &&
|
||||
downloadModelId === instanceModelId
|
||||
) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: true,
|
||||
errorMessage:
|
||||
(downloadPayload.errorMessage as string) || "Download failed",
|
||||
progress: null,
|
||||
statusText: "FAILED",
|
||||
perNode: [],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (downloadKind !== "DownloadOngoing") continue;
|
||||
if (!downloadPayload) continue;
|
||||
|
||||
@@ -844,6 +922,8 @@
|
||||
const statusInfo = deriveInstanceStatus(instanceWrapped);
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
@@ -856,6 +936,8 @@
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
@@ -1451,7 +1533,7 @@
|
||||
|
||||
// Get ALL filtered previews based on current settings (matching minimum nodes)
|
||||
// Note: previewsData already contains previews for the selected model (fetched via API)
|
||||
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
|
||||
const filteredPreviews = $derived(() => {
|
||||
if (!selectedModelId || previewsData.length === 0) return [];
|
||||
|
||||
@@ -1584,7 +1666,86 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
@@ -1624,7 +1785,111 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Node Filter Indicator (top-right corner) -->
|
||||
{#if isFilterActive()}
|
||||
<button
|
||||
onclick={clearPreviewNodeFilter}
|
||||
class="absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Clear filter"
|
||||
>
|
||||
<span class="text-[10px] font-mono tracking-wider">
|
||||
FILTER: {nodeFilter.size}
|
||||
</span>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Chat Input - Below topology -->
|
||||
@@ -2061,6 +2326,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -2106,6 +2378,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
foundModel.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
foundModel.id,
|
||||
)}
|
||||
<span
|
||||
class="flex items-center justify-between gap-2 w-full pr-4"
|
||||
>
|
||||
@@ -2132,6 +2407,22 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate"
|
||||
>{foundModel.name || foundModel.id}</span
|
||||
>
|
||||
@@ -2204,6 +2495,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
model.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
model.id,
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
@@ -2244,6 +2538,23 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
aria-label="Image editing model"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span
|
||||
@@ -2564,7 +2875,36 @@
|
||||
<div
|
||||
class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden"
|
||||
>
|
||||
<TopologyGraph
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected - click to view details"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200"
|
||||
>TB CYCLE</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
|
||||
@@ -2993,6 +3333,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -172,33 +172,6 @@
|
||||
}
|
||||
|
||||
let downloadOverview = $state<NodeEntry[]>([]);
|
||||
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
|
||||
[],
|
||||
);
|
||||
|
||||
async function fetchModels() {
|
||||
try {
|
||||
const response = await fetch("/models");
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
models = data.data || [];
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch models:", error);
|
||||
}
|
||||
}
|
||||
|
||||
function getModelTotalBytes(
|
||||
modelId: string,
|
||||
downloadTotalBytes: number,
|
||||
): number {
|
||||
if (downloadTotalBytes > 0) return downloadTotalBytes;
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
if (model?.storage_size_megabytes) {
|
||||
return model.storage_size_megabytes * 1024 * 1024;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
try {
|
||||
@@ -373,7 +346,6 @@
|
||||
onMount(() => {
|
||||
// Ensure we fetch at least once when visiting downloads directly
|
||||
refreshState();
|
||||
fetchModels();
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -482,7 +454,7 @@
|
||||
{#if model.status !== "completed"}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(
|
||||
model.totalBytes,
|
||||
)}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -20,7 +21,10 @@ from loguru import logger
|
||||
from exo.master.image_store import ImageStore
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import (
|
||||
EXO_IMAGE_CACHE_DIR,
|
||||
EXO_MAX_CHUNK_SIZE,
|
||||
)
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import (
|
||||
@@ -56,8 +60,15 @@ from exo.shared.types.api import (
|
||||
PlacementPreview,
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
@@ -93,7 +104,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
@@ -102,7 +113,19 @@ def chunk_to_response(
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
|
||||
if isinstance(chunk, TokenChunk)
|
||||
else ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
],
|
||||
),
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
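For orientation, here is a minimal, self-contained sketch of the branch above: a tool-call chunk becomes an assistant delta carrying tool_calls instead of content. The types below are simplified stand-ins, not the project's actual pydantic models.

```python
from dataclasses import dataclass, field
from uuid import uuid4

# Hypothetical stand-ins for the project's ToolCallChunk / function-call types.
@dataclass
class FunctionCall:
    name: str
    arguments: str

@dataclass
class ToolCallChunkSketch:
    tool_calls: list[FunctionCall] = field(default_factory=list)

def tool_call_delta(chunk: ToolCallChunkSketch) -> dict:
    """Build an assistant delta that carries tool calls instead of text."""
    return {
        "role": "assistant",
        "tool_calls": [
            {"id": str(uuid4()), "index": i, "function": vars(tool)}
            for i, tool in enumerate(chunk.tool_calls)
        ],
    }

# One tool call becomes index 0 in the streamed delta.
print(tool_call_delta(ToolCallChunkSketch([FunctionCall("get_weather", '{"city": "SF"}')])))
```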
|
||||
@@ -162,8 +185,12 @@ class API:
|
||||
name="dashboard",
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
] = {}
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -310,11 +337,20 @@ class API:
|
||||
return placements[new_ids[0]]
|
||||
|
||||
async def get_placement_previews(
|
||||
self,
|
||||
model_id: ModelId,
|
||||
node_ids: Annotated[list[NodeId] | None, Query()] = None,
|
||||
) -> PlacementPreviewResponse:
|
||||
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
|
||||
previews: list[PlacementPreview] = []
|
||||
|
||||
# Create filtered topology if node_ids specified
|
||||
if node_ids and len(node_ids) > 0:
|
||||
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
|
||||
else:
|
||||
topology = self.state.topology
|
||||
|
||||
if len(list(topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
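The new node_ids parameter above relies on FastAPI collecting repeated query parameters into a list; a minimal sketch of that behaviour (endpoint path and parameter names taken from this diff, the handler body is illustrative only):

```python
from typing import Annotated

from fastapi import FastAPI, Query
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/instance/previews")
def previews(model_id: str, node_ids: Annotated[list[str] | None, Query()] = None):
    # node_ids is None when absent, otherwise the repeated values in request order.
    return {"model_id": model_id, "node_ids": node_ids}

client = TestClient(app)
resp = client.get("/instance/previews?model_id=llama&node_ids=a&node_ids=b")
assert resp.json() == {"model_id": "llama", "node_ids": ["a", "b"]}
```

This mirrors how the dashboard appends one &node_ids=... pair per filtered node.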
|
||||
@@ -327,9 +363,7 @@ class API:
|
||||
instance_combinations.extend(
|
||||
[
|
||||
(sharding, instance_meta, i)
|
||||
for i in range(1, len(list(topology.list_nodes())) + 1)
|
||||
]
|
||||
)
|
||||
# TODO: PDD
|
||||
@@ -347,7 +381,7 @@ class API:
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
except ValueError as exc:
|
||||
@@ -439,11 +473,13 @@ class API:
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
@@ -462,7 +498,8 @@ class API:
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
@@ -470,6 +507,7 @@ class API:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
@@ -498,11 +536,12 @@ class API:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
@@ -511,7 +550,18 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
@@ -529,6 +579,7 @@ class API:
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -539,6 +590,7 @@ class API:
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
@@ -554,7 +606,19 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -571,7 +635,7 @@ class API:
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
role="assistant", content=combined_text, tool_calls=tool_calls
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -729,7 +793,9 @@ class API:
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
@@ -838,7 +904,9 @@ class API:
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
@@ -994,7 +1062,6 @@ class API:
|
||||
await self._send(
|
||||
SendInputChunk(
|
||||
chunk=InputImageChunk(
|
||||
idx=chunk_index,
|
||||
model=resolved_model,
|
||||
command_id=command.command_id,
|
||||
data=chunk_data,
|
||||
@@ -1148,27 +1215,26 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
    if queue := self._image_generation_queues.get(
        event.command_id, None
    ):
        assert isinstance(event.chunk, ImageChunk)
        try:
            await queue.send(event.chunk)
        except BrokenResourceError:
            self._image_generation_queues.pop(
                event.command_id, None
            )
    if queue := self._chat_completion_queues.get(
        event.command_id, None
    ):
        assert not isinstance(event.chunk, ImageChunk)
        try:
            await queue.send(event.chunk)
        except BrokenResourceError:
            self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -87,12 +87,12 @@ def place_instance(
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
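The change above narrows the RDMA-only preference to the MlxJaccl backend. A small illustrative sketch of the fallback pattern (names and predicate are made up):

```python
# Illustrative only: prefer the RDMA-capable subset when the backend wants it
# and the subset is non-empty; otherwise keep every smallest cycle.
def pick_cycles(smallest_cycles, is_rdma_cycle, wants_rdma: bool):
    rdma_cycles = [c for c in smallest_cycles if is_rdma_cycle(c)]
    return rdma_cycles if wants_rdma and rdma_cycles else smallest_cycles

cycles = [["a", "b"], ["c", "d"]]
print(pick_cycles(cycles, lambda c: c == ["c", "d"], wants_rdma=True))   # [['c', 'd']]
print(pick_cycles(cycles, lambda c: c == ["c", "d"], wants_rdma=False))  # both cycles
```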
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
|
||||
@@ -197,49 +197,6 @@ def get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
cycles = cycle_digraph.get_cycles()
|
||||
expected_length = len(list(cycle_digraph.list_nodes()))
|
||||
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
|
||||
if not cycles:
|
||||
if expected_length > 1:
|
||||
logger.warning(
|
||||
f"No cycles of length {expected_length} found even though chosen subgraph contained {expected_length} nodes"
|
||||
)
|
||||
return []
|
||||
|
||||
cycle = cycles[0]
|
||||
|
||||
get_thunderbolt = False
|
||||
if cycle_digraph.is_thunderbolt_cycle(cycle):
|
||||
get_thunderbolt = True
|
||||
|
||||
logger.debug(f"Using thunderbolt cycle: {get_thunderbolt}")
|
||||
|
||||
hosts: list[Host] = []
|
||||
for i in range(len(cycle)):
|
||||
current_node = cycle.node_ids[i]
|
||||
next_node = cycle.node_ids[(i + 1) % len(cycle)]
|
||||
|
||||
for connection in cycle_digraph.get_all_connections_between(
|
||||
source=current_node, sink=next_node
|
||||
):
|
||||
if not isinstance(connection, SocketConnection):
|
||||
continue
|
||||
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
|
||||
host = Host(
|
||||
ip=connection.sink_multiaddr.ip_address,
|
||||
port=connection.sink_multiaddr.port,
|
||||
)
|
||||
hosts.append(host)
|
||||
break
|
||||
|
||||
return hosts
|
||||
|
||||
|
||||
def get_mlx_jaccl_devices_matrix(
|
||||
selected_cycle: list[NodeId],
|
||||
cycle_digraph: Topology,
|
||||
@@ -265,9 +222,6 @@ def get_mlx_jaccl_devices_matrix(
|
||||
matrix[i][j] = conn.source_rdma_iface
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to find interface name between {node_i} and {node_j}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all-to-all RDMA connections"
|
||||
)
|
||||
@@ -279,22 +233,11 @@ def _find_connection_ip(
|
||||
node_i: NodeId,
|
||||
node_j: NodeId,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||
if isinstance(connection, SocketConnection):
|
||||
yield connection.sink_multiaddr.ip_address
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
@@ -303,43 +246,19 @@ def _find_ip_prioritised(
|
||||
cycle_digraph: Topology,
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
) -> str | None:
|
||||
"""Find an IP address between nodes with prioritization.

Priority: ethernet > wifi > unknown > thunderbolt
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
if not ips:
    return None
other_network = node_network.get(other_node_id, NodeNetworkInfo())
ip_to_type = {
    iface.ip_address: iface.interface_type for iface in other_network.interfaces
}

priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
@@ -381,9 +300,6 @@ def get_mlx_ring_hosts_by_node(
|
||||
node_id, other_node_id, cycle_digraph, node_network
|
||||
)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
@@ -416,9 +332,6 @@ def get_mlx_jaccl_coordinators(
|
||||
if ip is not None:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all participating devices to be able to communicate"
|
||||
)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
|
||||
|
||||
def test_image_chunk_with_error_fields() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message="Image generation failed",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Image generation failed"
|
||||
assert chunk.data == ""
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.total_chunks == 1
|
||||
assert chunk.image_index == 0
|
||||
|
||||
|
||||
def test_image_chunk_without_error() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="base64encodeddata",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
assert chunk.data == "base64encodeddata"
|
||||
|
||||
@@ -3,7 +3,6 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
@@ -14,7 +13,7 @@ from exo.master.tests.conftest import (
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
NetworkInterfaceInfo,
|
||||
@@ -273,45 +272,6 @@ def test_get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def test_get_hosts_from_subgraph():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
connection1 = Connection(
|
||||
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
|
||||
)
|
||||
connection2 = Connection(
|
||||
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
|
||||
)
|
||||
connection3 = Connection(
|
||||
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
|
||||
)
|
||||
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
topology.add_connection(connection3)
|
||||
|
||||
# act
|
||||
hosts = get_hosts_from_subgraph(topology)
|
||||
|
||||
# assert
|
||||
assert len(hosts) == 3
|
||||
expected_hosts = [
|
||||
Host(ip="169.254.0.1", port=1234),
|
||||
Host(ip="169.254.0.2", port=1234),
|
||||
Host(ip="169.254.0.3", port=1234),
|
||||
]
|
||||
for expected_host in expected_hosts:
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_jaccl_coordinators():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
|
||||
@@ -30,6 +30,7 @@ from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -46,6 +47,7 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
NodeConfig,
|
||||
NodeNetworkInterfaces,
|
||||
StaticNodeInformation,
|
||||
ThunderboltBridgeInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -225,6 +227,21 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
for key, value in state.node_thunderbolt.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_thunderbolt_bridge = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
# Only recompute cycles if the leaving node had TB bridge enabled
|
||||
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
leaving_node_had_tb_enabled = (
|
||||
leaving_node_status is not None and leaving_node_status.enabled
|
||||
)
|
||||
thunderbolt_bridge_cycles = (
|
||||
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
|
||||
if leaving_node_had_tb_enabled
|
||||
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
|
||||
)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
@@ -235,6 +252,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -312,6 +331,22 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
]
|
||||
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
|
||||
case ThunderboltBridgeInfo():
|
||||
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
|
||||
**state.node_thunderbolt_bridge,
|
||||
event.node_id: info.status,
|
||||
}
|
||||
update["node_thunderbolt_bridge"] = new_tb_bridge
|
||||
# Only recompute cycles if the enabled status changed
|
||||
old_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
old_enabled = old_status.enabled if old_status else False
|
||||
new_enabled = info.status.enabled
|
||||
if old_enabled != new_enabled:
|
||||
update["thunderbolt_bridge_cycles"] = (
|
||||
topology.get_thunderbolt_bridge_cycles(
|
||||
new_tb_bridge, state.node_network
|
||||
)
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
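The old_enabled != new_enabled guard above avoids recomputing bridge cycles on every gathered-info event. A stripped-down sketch of that pattern, with plain dictionaries standing in for the real status objects:

```python
# Simplified sketch: only recompute when the node's `enabled` flag actually flips.
def maybe_recompute_cycles(old_status, new_enabled, recompute, cached):
    old_enabled = bool(old_status and old_status.get("enabled"))
    return recompute() if old_enabled != new_enabled else cached

cycles = maybe_recompute_cycles(
    old_status={"enabled": False},
    new_enabled=True,
    recompute=lambda: [["node-a", "node-b"]],  # pretend cycle detection ran
    cached=[],
)
assert cycles == [["node-a", "node-b"]]
```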
|
||||
|
||||
|
||||
@@ -49,3 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
|
||||
|
||||
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
|
||||
|
||||
EXO_ENABLE_IMAGE_MODELS = (
|
||||
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
|
||||
)
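Presumably this flag is what gates the _IMAGE_MODEL_CARDS table introduced further down in model_cards.py (that wiring is an assumption here); the flag itself is just the usual env-var toggle:

```python
import os

# Same parsing convention as EXO_ENABLE_IMAGE_MODELS above.
def env_flag(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).lower() == "true"

base_cards = {"llama-3.1-8b-instruct": "..."}   # stand-in for MODEL_CARDS
image_cards = {"flux1-schnell": "..."}          # stand-in for _IMAGE_MODEL_CARDS

cards = dict(base_cards)
if env_flag("EXO_ENABLE_IMAGE_MODELS"):
    cards.update(image_cards)  # image models only exposed when explicitly opted in
```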
|
||||
|
||||
@@ -9,6 +9,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -410,161 +411,166 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
|
||||
# "flux1-schnell": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23782357120),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "flux1-dev": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23802816640),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image-edit-2509": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}

if EXO_ENABLE_IMAGE_MODELS:
    MODEL_CARDS.update(_IMAGE_MODEL_CARDS)


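Note (illustration, not part of the diff): because the flag above is read once at import time, the experimental image model cards only appear if the environment variable is set before exo's modules load. A minimal sketch of checking that, assuming the module path shown elsewhere in this change:

# Hypothetical usage sketch: enable the experimental image models via the env var
# before exo starts (e.g. `EXO_ENABLE_IMAGE_MODELS=true exo`), or set it in-process
# before exo's modules are imported.
import os

os.environ["EXO_ENABLE_IMAGE_MODELS"] = "true"  # must be set before import

from exo.shared.models.model_cards import MODEL_CARDS

assert "flux1-schnell" in MODEL_CARDS  # present only when the flag is truthy
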
class ConfigData(BaseModel):
    model_config = {"extra": "ignore"} # Allow unknown fields
@@ -627,7 +633,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
        "main",
        "config.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
        lambda curr_bytes, total_bytes, is_renamed: logger.debug(
            f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )
@@ -650,7 +656,7 @@ async def get_safetensors_size(model_id: ModelId) -> Memory:
        "main",
        "model.safetensors.index.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
        lambda curr_bytes, total_bytes, is_renamed: logger.debug(
            f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )

@@ -7,6 +7,11 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict

from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
    InterfaceType,
    NodeNetworkInfo,
    ThunderboltBridgeStatus,
)
from exo.shared.types.topology import (
    Connection,
    Cycle,
@@ -188,24 +193,25 @@ class Topology:
            cycles.append(Cycle(node_ids=[node_id]))
        return cycles

    def get_cycles_tb(self) -> list[Cycle]:
        tb_edges = [
    def get_rdma_cycles(self) -> list[Cycle]:
        rdma_edges = [
            (u, v, conn)
            for u, v, conn in self._graph.weighted_edge_list()
            if conn.is_thunderbolt()
            if isinstance(conn, RDMAConnection)
        ]

        tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
        tb_graph.add_nodes_from(self._graph.nodes())
        rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
            rx.PyDiGraph()
        )
        rdma_graph.add_nodes_from(self._graph.nodes())

        for u, v, conn in tb_edges:
            if isinstance(conn, SocketConnection):
                tb_graph.add_edge(u, v, conn)
        for u, v, conn in rdma_edges:
            rdma_graph.add_edge(u, v, conn)

        cycle_idxs = rx.simple_cycles(tb_graph)
        cycle_idxs = rx.simple_cycles(rdma_graph)
        cycles: list[Cycle] = []
        for cycle_idx in cycle_idxs:
            cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
            cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])
            cycles.append(cycle)

        return cycles
@@ -219,18 +225,83 @@ class Topology:
            topology.add_connection(connection)
        return topology

    def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
    def is_rdma_cycle(self, cycle: Cycle) -> bool:
        node_idxs = [node for node in cycle]
        rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
        for rid in rx_idxs:
            for neighbor_rid in self._graph.neighbors(rid):
                if neighbor_rid not in rx_idxs:
                    continue
                has_tb = False
                has_rdma = False
                for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
                    if edge.is_thunderbolt():
                        has_tb = True
                    if isinstance(edge, RDMAConnection):
                        has_rdma = True
                        break
                if not has_tb:
                if not has_rdma:
                    return False
        return True

    def get_thunderbolt_bridge_cycles(
        self,
        node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
        node_network: Mapping[NodeId, NodeNetworkInfo],
    ) -> list[list[NodeId]]:
        """
        Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
        Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
        2 or fewer nodes don't cause the broadcast storm problem.
        """
        enabled_nodes = {
            node_id
            for node_id, status in node_tb_bridge_status.items()
            if status.enabled
        }

        if len(enabled_nodes) < 3:
            return []

        thunderbolt_ips = _get_ips_with_interface_type(
            enabled_nodes, node_network, "thunderbolt"
        )

        # Build subgraph with only TB bridge enabled nodes and thunderbolt connections
        graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()
        node_to_idx: dict[NodeId, int] = {}

        for node_id in enabled_nodes:
            if node_id in self._vertex_indices:
                node_to_idx[node_id] = graph.add_node(node_id)

        for u, v, conn in self._graph.weighted_edge_list():
            source_id, sink_id = self._graph[u], self._graph[v]
            if source_id not in node_to_idx or sink_id not in node_to_idx:
                continue
            # Include connection if it's over a thunderbolt interface
            if (
                isinstance(conn, SocketConnection)
                and conn.sink_multiaddr.ip_address in thunderbolt_ips
            ):
                graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
            if isinstance(conn, RDMAConnection):
                graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)

        return [
            [graph[idx] for idx in cycle]
            for cycle in rx.simple_cycles(graph)
            if len(cycle) > 2
        ]


def _get_ips_with_interface_type(
    node_ids: set[NodeId],
    node_network: Mapping[NodeId, NodeNetworkInfo],
    interface_type: InterfaceType,
) -> set[str]:
    """Get all IP addresses on interfaces of the specified type for the given nodes."""
    ips: set[str] = set()
    for node_id in node_ids:
        network_info = node_network.get(node_id, NodeNetworkInfo())
        for iface in network_info.interfaces:
            if iface.interface_type == interface_type:
                ips.add(iface.ip_address)
    return ips

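A minimal sketch (not part of the diff) of how the cycle check above could be exercised; the three node IDs and the enclosing reducer context are made up for illustration:

# Hypothetical illustration: three nodes, all with Thunderbolt Bridge enabled and
# connected in a ring over thunderbolt-interface IPs, should be reported as one cycle.
statuses = {
    node_id: ThunderboltBridgeStatus(
        enabled=True, exists=True, service_name="Thunderbolt Bridge"
    )
    for node_id in (node_a, node_b, node_c)  # assumed NodeIds already in the topology
}
cycles = topology.get_thunderbolt_bridge_cycles(statuses, state.node_network)
if cycles:
    # Each entry is a list of NodeIds forming a loop of 3+ machines, i.e. the
    # configuration that can trigger a broadcast storm when bridging is enabled.
    logger.warning(f"Thunderbolt Bridge loop detected: {cycles}")
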
@@ -54,6 +54,18 @@ class ChatCompletionMessageText(BaseModel):
    text: str


class ToolCallItem(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str
    index: int | None = None
    type: Literal["function"] = "function"
    function: ToolCallItem


class ChatCompletionMessage(BaseModel):
    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
    content: (
@@ -61,7 +73,7 @@ class ChatCompletionMessage(BaseModel):
    ) = None
    thinking: str | None = None # Added for GPT-OSS harmony format support
    name: str | None = None
    tool_calls: list[dict[str, Any]] | None = None
    tool_calls: list[ToolCall] | None = None
    tool_call_id: str | None = None
    function_call: dict[str, Any] | None = None

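For context (illustration only, values invented): the typed `tool_calls` field mirrors the OpenAI wire shape, with `arguments` carried as a JSON string.

# Hypothetical example of an assistant message carrying one tool call.
msg = ChatCompletionMessage(
    role="assistant",
    content=None,
    tool_calls=[
        ToolCall(
            id="call_0",
            index=0,
            function=ToolCallItem(name="get_weather", arguments='{"city": "Paris"}'),
        )
    ],
)
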
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
from .common import CommandId
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
|
||||
Token = "Token"
|
||||
Image = "Image"
|
||||
from .worker.runner_response import ToolCallItem
|
||||
|
||||
|
||||
class BaseChunk(TaggedModel):
|
||||
idx: int
|
||||
model: ModelId
|
||||
|
||||
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ErrorChunk(BaseChunk):
|
||||
error_message: str
|
||||
finish_reason: Literal["error"] = "error"
|
||||
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Self
|
||||
from typing import Literal, Self
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -48,9 +48,13 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
ecpu_usage: float = 0.0
|
||||
|
||||
|
||||
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
|
||||
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
name: str
|
||||
ip_address: str
|
||||
interface_type: InterfaceType = "unknown"
|
||||
|
||||
|
||||
class NodeIdentity(CamelCaseModel):
|
||||
@@ -71,3 +75,11 @@ class NodeThunderboltInfo(CamelCaseModel):
|
||||
"""Thunderbolt interface identifiers for a node."""
|
||||
|
||||
interfaces: Sequence[ThunderboltIdentifier] = []
|
||||
|
||||
|
||||
class ThunderboltBridgeStatus(CamelCaseModel):
|
||||
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
|
||||
|
||||
enabled: bool
|
||||
exists: bool
|
||||
service_name: str | None = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
SystemPerformanceProfile,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
|
||||
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
|
||||
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
|
||||
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
|
||||
|
||||
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
|
||||
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
|
||||
|
||||
@field_serializer("topology", mode="plain")
|
||||
def _encode_topology(self, value: Topology) -> TopologySnapshot:
|
||||
|
||||
@@ -21,9 +21,6 @@ class RDMAConnection(FrozenModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SocketConnection(FrozenModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
@@ -31,9 +28,6 @@ class SocketConnection(FrozenModel):
|
||||
def __hash__(self):
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
|
||||
class Connection(FrozenModel):
|
||||
source: NodeId
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import (
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -48,5 +53,9 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
yield name, value
|
||||
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.thunderbolt import (
|
||||
ThunderboltConnection,
|
||||
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
|
||||
IS_DARWIN = sys.platform == "darwin"
|
||||
|
||||
|
||||
async def _get_thunderbolt_devices() -> set[str] | None:
|
||||
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listallhardwareports"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listallhardwareports failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
output = result.stdout.decode()
|
||||
thunderbolt_devices: set[str] = set()
|
||||
current_port: str | None = None
|
||||
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("Hardware Port:"):
|
||||
current_port = line.split(":", 1)[1].strip()
|
||||
elif line.startswith("Device:") and current_port:
|
||||
device = line.split(":", 1)[1].strip()
|
||||
if "thunderbolt" in current_port.lower():
|
||||
thunderbolt_devices.add(device)
|
||||
current_port = None
|
||||
|
||||
return thunderbolt_devices
|
||||
|
||||
|
||||
async def _get_bridge_services() -> dict[str, str] | None:
|
||||
"""Get mapping of bridge device -> service name from network service order.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listnetworkserviceorder"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listnetworkserviceorder failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse service order to find bridge devices and their service names
|
||||
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
|
||||
service_order_output = result.stdout.decode()
|
||||
bridge_services: dict[str, str] = {} # device -> service name
|
||||
current_service: str | None = None
|
||||
|
||||
for line in service_order_output.splitlines():
|
||||
line = line.strip()
|
||||
# Match "(N) Service Name" or "(*) Service Name" (disabled)
|
||||
# but NOT "(Hardware Port: ...)" lines
|
||||
if (
|
||||
line
|
||||
and line.startswith("(")
|
||||
and ")" in line
|
||||
and not line.startswith("(Hardware Port:")
|
||||
):
|
||||
paren_end = line.index(")")
|
||||
if paren_end + 2 <= len(line):
|
||||
current_service = line[paren_end + 2 :]
|
||||
# Match "(Hardware Port: ..., Device: bridgeX)"
|
||||
elif current_service and "Device: bridge" in line:
|
||||
# Extract device name from "..., Device: bridge0)"
|
||||
device_start = line.find("Device: ") + len("Device: ")
|
||||
device_end = line.find(")", device_start)
|
||||
if device_end > device_start:
|
||||
device = line[device_start:device_end]
|
||||
bridge_services[device] = current_service
|
||||
|
||||
return bridge_services
|
||||
|
||||
|
||||
async def _get_bridge_members(bridge_device: str) -> set[str]:
|
||||
"""Get member interfaces of a bridge device via ifconfig."""
|
||||
result = await anyio.run_process(
|
||||
["ifconfig", bridge_device],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
|
||||
return set()
|
||||
|
||||
members: set[str] = set()
|
||||
ifconfig_output = result.stdout.decode()
|
||||
for line in ifconfig_output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("member:"):
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
members.add(parts[1])
|
||||
|
||||
return members
|
||||
|
||||
|
||||
async def _find_thunderbolt_bridge(
|
||||
bridge_services: dict[str, str], thunderbolt_devices: set[str]
|
||||
) -> str | None:
|
||||
"""Find the service name of a bridge containing Thunderbolt interfaces.
|
||||
|
||||
Returns the service name if found, None otherwise.
|
||||
"""
|
||||
for bridge_device, service_name in bridge_services.items():
|
||||
members = await _get_bridge_members(bridge_device)
|
||||
if members & thunderbolt_devices: # intersection is non-empty
|
||||
return service_name
|
||||
return None
|
||||
|
||||
|
||||
async def _is_service_enabled(service_name: str) -> bool | None:
|
||||
"""Check if a network service is enabled.
|
||||
|
||||
Returns True if enabled, False if disabled, None on error.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-getnetworkserviceenabled", service_name],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -getnetworkserviceenabled '{service_name}' "
|
||||
f"failed with code {result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
stdout = result.stdout.decode().strip().lower()
|
||||
return stdout == "enabled"
|
||||
|
||||
|
||||
class StaticNodeInformation(TaggedModel):
|
||||
"""Node information that should NEVER change, to be gathered once at startup"""
|
||||
|
||||
@@ -58,6 +195,66 @@ class MacThunderboltConnections(TaggedModel):
|
||||
conns: Sequence[ThunderboltConnection]
|
||||
|
||||
|
||||
class ThunderboltBridgeInfo(TaggedModel):
|
||||
status: ThunderboltBridgeStatus
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> Self | None:
|
||||
"""Check if a Thunderbolt Bridge network service is enabled on this node.
|
||||
|
||||
Detection approach:
|
||||
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
|
||||
2. Find bridge devices from network service order (not hardware ports, as
|
||||
bridges may not appear there)
|
||||
3. Check each bridge's members via ifconfig
|
||||
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
|
||||
5. Check if that network service is enabled
|
||||
"""
|
||||
if not IS_DARWIN:
|
||||
return None
|
||||
|
||||
def _no_bridge_status() -> Self:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=False, service_name=None
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
tb_devices = await _get_thunderbolt_devices()
|
||||
if tb_devices is None:
|
||||
return _no_bridge_status()
|
||||
|
||||
bridge_services = await _get_bridge_services()
|
||||
if not bridge_services:
|
||||
return _no_bridge_status()
|
||||
|
||||
tb_service_name = await _find_thunderbolt_bridge(
|
||||
bridge_services, tb_devices
|
||||
)
|
||||
if not tb_service_name:
|
||||
return _no_bridge_status()
|
||||
|
||||
enabled = await _is_service_enabled(tb_service_name)
|
||||
if enabled is None:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=True, service_name=tb_service_name
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=enabled,
|
||||
exists=True,
|
||||
service_name=tb_service_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class NodeConfig(TaggedModel):
|
||||
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||
|
||||
@@ -111,6 +308,7 @@ GatheredInfo = (
|
||||
| NodeNetworkInterfaces
|
||||
| MacThunderboltIdentifiers
|
||||
| MacThunderboltConnections
|
||||
| ThunderboltBridgeInfo
|
||||
| NodeConfig
|
||||
| MiscData
|
||||
| StaticNodeInformation
|
||||
@@ -125,6 +323,7 @@ class InfoGatherer:
|
||||
system_profiler_interval: float | None = 5 if IS_DARWIN else None
|
||||
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
async def run(self):
|
||||
@@ -133,6 +332,7 @@ class InfoGatherer:
|
||||
if (macmon_path := shutil.which("macmon")) is not None:
|
||||
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
|
||||
tg.start_soon(self._monitor_thunderbolt_bridge_status)
|
||||
tg.start_soon(self._watch_system_info)
|
||||
tg.start_soon(self._monitor_memory_usage)
|
||||
tg.start_soon(self._monitor_misc)
|
||||
@@ -206,6 +406,17 @@ class InfoGatherer:
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
async def _monitor_macmon(self, macmon_path: str):
|
||||
if self.macmon_interval is None:
|
||||
return
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import socket
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from subprocess import CalledProcessError, run
|
||||
|
||||
import psutil
|
||||
from anyio import run_process
|
||||
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
|
||||
|
||||
|
||||
async def get_friendly_name() -> str:
|
||||
@@ -28,6 +28,38 @@ async def get_friendly_name() -> str:
|
||||
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
|
||||
|
||||
|
||||
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
"""Parse networksetup -listallhardwareports to get interface types."""
|
||||
if sys.platform != "darwin":
|
||||
return {}
|
||||
try:
|
||||
result = run(
|
||||
["networksetup", "-listallhardwareports"], capture_output=True, text=True
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
types: dict[str, InterfaceType] = {}
|
||||
current_type: InterfaceType = "unknown"
|
||||
|
||||
for line in result.stdout.splitlines():
|
||||
if line.startswith("Hardware Port:"):
|
||||
port_name = line.split(":", 1)[1].strip()
|
||||
if "Wi-Fi" in port_name:
|
||||
current_type = "wifi"
|
||||
elif "Ethernet" in port_name or "LAN" in port_name:
|
||||
current_type = "ethernet"
|
||||
elif port_name.startswith("Thunderbolt"):
|
||||
current_type = "thunderbolt"
|
||||
else:
|
||||
current_type = "unknown"
|
||||
elif line.startswith("Device:"):
|
||||
device = line.split(":", 1)[1].strip()
|
||||
types[device] = current_type
|
||||
|
||||
return types
|
||||
|
||||
|
||||
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
"""
|
||||
Retrieves detailed network interface information on macOS.
|
||||
@@ -36,13 +68,18 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
Returns a list of NetworkInterfaceInfo objects.
|
||||
"""
|
||||
interfaces_info: list[NetworkInterfaceInfo] = []
|
||||
interface_types = _get_interface_types_from_networksetup()
|
||||
|
||||
for iface, services in psutil.net_if_addrs().items():
|
||||
for service in services:
|
||||
match service.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
interfaces_info.append(
|
||||
NetworkInterfaceInfo(name=iface, ip_address=service.address)
|
||||
NetworkInterfaceInfo(
|
||||
name=iface,
|
||||
ip_address=service.address,
|
||||
interface_type=interface_types.get(iface, "unknown"),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -40,9 +40,31 @@ from exo.worker.download.huggingface_utils import (
    get_allow_patterns,
    get_auth_headers,
    get_hf_endpoint,
    get_hf_token,
)


class HuggingFaceAuthenticationError(Exception):
    """Raised when HuggingFace returns 401/403 for a model download."""


async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
    token = await get_hf_token()
    if status_code == 401 and token is None:
        return (
            f"Model '{model_id}' requires authentication. "
            f"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. "
            f"Get a token at https://huggingface.co/settings/tokens"
        )
    elif status_code == 403:
        return (
            f"Access denied to '{model_id}'. "
            f"Please accept the model terms at https://huggingface.co/{model_id}"
        )
    else:
        return f"Authentication failed for '{model_id}' (HTTP {status_code})"


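A sketch (not from the diff) of how a caller might surface this error; the exact argument list of the retry wrapper is an assumption here:

# Hypothetical caller-side handling: authentication failures are re-raised rather
# than retried, so the message built above reaches the user directly.
try:
    file_list = await fetch_file_list_with_retry(model_id, "main", "")  # assumed args
except HuggingFaceAuthenticationError as e:
    logger.error(str(e))  # e.g. prompts the user to set HF_TOKEN or accept model terms
    raise
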
def trim_etag(etag: str) -> str:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
return etag[1:-1]
|
||||
@@ -147,6 +169,8 @@ async def fetch_file_list_with_retry(
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -167,6 +191,9 @@ async def _fetch_file_list(
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.get(url, headers=headers) as response,
|
||||
):
|
||||
if response.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if response.status == 200:
|
||||
data_json = await response.text()
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
@@ -256,6 +283,9 @@ async def file_meta(
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -279,6 +309,8 @@ async def download_file_with_retry(
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -322,6 +354,9 @@ async def _download_file(
|
||||
):
|
||||
if r.status == 404:
|
||||
raise FileNotFoundError(f"File not found: {url}")
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
)
|
||||
@@ -463,7 +498,7 @@ async def download_shard(
|
||||
allow_patterns: list[str] | None = None,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=}")
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=}")
|
||||
|
||||
revision = "main"
|
||||
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
|
||||
@@ -476,7 +511,7 @@ async def download_shard(
|
||||
allow_patterns = await resolve_allow_patterns(shard)
|
||||
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
|
||||
@@ -68,7 +68,11 @@ def get_hf_home() -> Path:
|
||||
|
||||
|
||||
async def get_hf_token() -> str | None:
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
"""Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory."""
|
||||
# Check environment variable first
|
||||
if token := os.environ.get("HF_TOKEN"):
|
||||
return token
|
||||
# Fall back to file-based token
|
||||
token_path = get_hf_home() / "token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, "r") as f:
|
||||
|
||||
@@ -3,6 +3,8 @@ from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
@@ -19,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -166,7 +168,7 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
print("Error downloading shard:", e)
|
||||
logger.error("Error downloading shard:", e)
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
|
||||
@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
|
||||
@@ -170,10 +170,10 @@ def mlx_distributed_init(
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
    return tokenizer


def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
    """
    Normalize tool_calls in a message dict.

    OpenAI format has tool_calls[].function.arguments as a JSON string,
    but some chat templates (e.g., GLM) expect it as a dict.
    """
    tool_calls = msg_dict.get("tool_calls")
    if not tool_calls or not isinstance(tool_calls, list):
        return

    for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
        if not isinstance(tc, dict):
            continue
        func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
        if not isinstance(func, dict):
            continue
        args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
        if isinstance(args, str):
            with contextlib.suppress(json.JSONDecodeError):
                func["arguments"] = json.loads(args)


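To make the normalization concrete (illustrative values only):

# Before: OpenAI-style message dict with arguments as a JSON string.
msg = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "multiply", "arguments": '{"a": 2, "b": 3}'},
        }
    ],
}
_normalize_tool_calls(msg)
# After: arguments is a dict, which templates like GLM's expect.
assert msg["tool_calls"][0]["function"]["arguments"] == {"a": 2, "b": 3}
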
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
tools = chat_task_data.tools
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
@@ -386,15 +409,19 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
dumped: dict[str, Any] = message.model_dump()
|
||||
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
|
||||
|
||||
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
|
||||
_normalize_tool_calls(msg_dict)
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
@@ -38,6 +38,7 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
@@ -443,7 +444,33 @@ class Worker:
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||
|
||||
async def download_with_error_handling() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(task.shard_metadata)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(
|
||||
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
|
||||
)
|
||||
failed_status = DownloadFailed(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
error_message=error_message,
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = (
|
||||
failed_status
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed_status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Failed
|
||||
)
|
||||
)
|
||||
|
||||
self._tg.start_soon(download_with_error_handling)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
|
||||
@@ -20,6 +20,7 @@ from exo.shared.types.tasks import (
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
@@ -122,7 +123,8 @@ def _model_needs_download(
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
download_status[model_id],
|
||||
(DownloadOngoing, DownloadCompleted, DownloadFailed),
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Literal
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -13,11 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
@@ -42,6 +44,8 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
ToolCallItem,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -154,6 +158,9 @@ def main(
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
@@ -244,17 +251,53 @@ def main(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
if "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0:
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
@@ -263,6 +306,17 @@ def main(
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
@@ -270,11 +324,8 @@ def main(
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -328,18 +379,14 @@ def main(
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -396,13 +443,8 @@ def main(
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=0,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
@@ -446,6 +488,18 @@ def get_gpt_oss_encoding():
|
||||
return encoding
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
):
|
||||
continue
|
||||
yield resp
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
@@ -526,7 +580,6 @@ def _send_image_chunk(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
@@ -568,6 +621,196 @@ def _process_image_response(
|
||||
)
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
in_tool_call = True
|
||||
continue
|
||||
# assumption: the tool call end is one token
|
||||
if in_tool_call and response.text == tool_call_end:
|
||||
try:
|
||||
# tool_parser returns an arbitrarily nested python dictionary
|
||||
# we actually don't want the python dictionary, we just want to
|
||||
# parse the top level { function: ..., arguments: ... } structure
|
||||
# as we're just gonna hand it back to the api anyway
|
||||
parsed = tool_parser("".join(tool_call_text_parts).strip())
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if isinstance(parsed, list):
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
|
||||
)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
continue
|
||||
# fallthrough
|
||||
yield response
|
||||
|
||||
|
||||
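A small sketch (not part of the change) of how the generator above behaves, using the GLM-style markers and a stand-in parser; the `responses` stream and its GenerationResponse items are assumed:

# Hypothetical walkthrough of parse_tool_calls.
def fake_parser(body: str) -> dict:
    # Stands in for tokenizer.tool_parser; real parsers return {"name": ..., "arguments": ...}.
    return {"name": "multiply", "arguments": {"a": 2, "b": 3}}

# `responses` is assumed to yield GenerationResponse items whose .text fields are, in order:
# "<tool_call>", the serialized call body, "</tool_call>", then ordinary tokens.
for item in parse_tool_calls(responses, "<tool_call>", "</tool_call>", fake_parser):
    if isinstance(item, ToolCallResponse):
        # Emitted once per complete tool call when the end marker is seen.
        print(item.tool_calls)
    else:
        print(item.text)  # ordinary tokens pass through unchanged
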
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Version of to-be-upstreamed kimi-k2 tool parser
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
|
||||
tool_call_start = "<|tool_call_begin|>"
|
||||
tool_call_end = "<|tool_call_end|>"
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
    if (
        ((name := obj.get("name")) is not None)
        and ((args := obj.get("arguments")) is not None)
        and isinstance(name, str)
    ):
        return ToolCallItem(name=name, arguments=json.dumps(args))
    else:
        raise ValidationError
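# Illustration only, not part of the diff: the expected behaviour of _validate_single_tool,
# assuming ToolCallItem exposes the name/arguments fields it is constructed with above.
# Dicts missing "name" or "arguments", or whose name is not a string, hit `raise ValidationError`.
def _demo_validate_single_tool() -> None:
    item = _validate_single_tool({"name": "get_weather", "arguments": {"city": "Paris"}})
    assert item.name == "get_weather"
    assert item.arguments == '{"city": "Paris"}'  # arguments are re-serialized to a JSON string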
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # initialize_mlx returns a "group" equal to 1
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
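# Illustration only, not part of the diff: a plausible shape for the `make_nothin` and
# `nothin` helpers assumed by the monkeypatches above - a stub factory whose returned
# callable ignores its arguments and hands back a canned value. The real helpers are
# defined elsewhere in this test module.
def make_nothin(result):
    def _stub(*args, **kwargs):
        return result
    return _stub

def nothin(*args, **kwargs) -> None:
    return None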
@@ -140,6 +140,13 @@ class EventCollector:
        pass


class MockTokenizer:
    tool_parser = None
    tool_call_start = None
    tool_call_end = None
    has_tool_calling = False


def _run(tasks: Iterable[Task]):
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
@@ -171,7 +178,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
    expected_chunk = ChunkGenerated(
        command_id=COMMAND_1_ID,
        chunk=TokenChunk(
            idx=0,
            model=MODEL_A_ID,
            text="hi",
            token_id=0,
59  uv.lock  generated
@@ -376,8 +376,8 @@ dependencies = [
    { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -413,7 +413,7 @@ requires-dist = [
    { name = "hypercorn", specifier = ">=0.18.0" },
    { name = "loguru", specifier = ">=0.7.3" },
    { name = "mflux", specifier = ">=0.14.2" },
    { name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
    { name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
    { name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
    { name = "openai-harmony", specifier = ">=0.0.8" },
@@ -458,16 +458,6 @@ dev = [
    { name = "pytest-asyncio", specifier = ">=1.0.0" },
]

[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]


[[package]]
name = "fastapi"
version = "0.128.0"
@@ -1004,8 +994,8 @@ dependencies = [
    { name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1032,18 +1022,12 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mlx-metal", marker = "sys_platform == 'darwin'" },
resolution-markers = [
    "sys_platform == 'linux'",
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
    { url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
    { url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
    { url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
    { url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
    { url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
    { url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
    { url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
    { url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
    { url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1056,6 +1040,14 @@ cuda13 = [
    { name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]

[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
    "sys_platform == 'darwin'",
]

[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1086,7 +1078,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
    { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", marker = "sys_platform == 'darwin'" },
    { name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1094,16 +1086,6 @@ dependencies = [
    { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]

[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
    { url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
    { url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]

[[package]]
name = "more-itertools"
version = "10.8.0"
@@ -2227,6 +2209,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]

[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]

[[package]]
name = "torch"
version = "2.9.1"