mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-20 20:10:10 -05:00
Compare commits
20 Commits
foo
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56ae1e22e3 | ||
|
|
215132c450 | ||
|
|
82040e7e77 | ||
|
|
86fb761c69 | ||
|
|
08252a81c4 | ||
|
|
3b62fa9c1f | ||
|
|
ad7f5b746b | ||
|
|
bcec2bd5b0 | ||
|
|
6ca54112c0 | ||
|
|
9267809d2b | ||
|
|
9cd2dd48aa | ||
|
|
cadd66c1a1 | ||
|
|
fb90e5dc1b | ||
|
|
844b8f6739 | ||
|
|
069be5fa23 | ||
|
|
fcc92e5d5e | ||
|
|
63b74309f5 | ||
|
|
758464703d | ||
|
|
9e2179c848 | ||
|
|
22b5d836ef |
2
.github/workflows/build-app.yml
vendored
2
.github/workflows/build-app.yml
vendored
@@ -315,7 +315,7 @@ jobs:
|
||||
--apple-id "$APPLE_NOTARIZATION_USERNAME" \
|
||||
--password "$APPLE_NOTARIZATION_PASSWORD" \
|
||||
--team-id "$APPLE_NOTARIZATION_TEAM" \
|
||||
--wait --timeout 15m 2>&1)
|
||||
--wait --timeout 20m 2>&1)
|
||||
echo "$SUBMISSION_OUTPUT"
|
||||
|
||||
SUBMISSION_ID=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="id:" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')
|
||||
|
||||
@@ -14,6 +14,7 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
|
||||
@@ -21,6 +21,7 @@ struct EXOApp: App {
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
|
||||
@@ -41,10 +42,13 @@ struct EXOApp: App {
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
|
||||
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
// Remove old LaunchDaemon components if they exist (from previous versions)
|
||||
cleanupLegacyNetworkSetup()
|
||||
// Check local network access periodically (warning disappears when user grants permission)
|
||||
localNetwork.startPeriodicChecking(interval: 10)
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -58,6 +62,7 @@ struct EXOApp: App {
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
.environmentObject(thunderboltBridgeService)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
}
|
||||
@@ -130,6 +135,20 @@ struct EXOApp: App {
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
private func cleanupLegacyNetworkSetup() {
|
||||
guard NetworkSetupHelper.hasInstalledComponents() else { return }
|
||||
do {
|
||||
try NetworkSetupHelper.uninstall()
|
||||
Logger().info("Cleaned up legacy network setup components")
|
||||
} catch {
|
||||
// Non-fatal: user may have cancelled admin prompt or cleanup may have
|
||||
// partially succeeded. The app will continue normally.
|
||||
Logger().warning(
|
||||
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
|
||||
@@ -5,17 +5,43 @@ import Foundation
|
||||
struct ClusterState: Decodable {
|
||||
let instances: [String: ClusterInstance]
|
||||
let runners: [String: RunnerStatusSummary]
|
||||
let nodeProfiles: [String: NodeProfile]
|
||||
let tasks: [String: ClusterTask]
|
||||
let topology: Topology?
|
||||
let downloads: [String: [NodeDownloadStatus]]
|
||||
let thunderboltBridgeCycles: [[String]]
|
||||
|
||||
// Granular node state (split from the old nodeProfiles)
|
||||
let nodeIdentities: [String: NodeIdentity]
|
||||
let nodeMemory: [String: MemoryInfo]
|
||||
let nodeSystem: [String: SystemInfo]
|
||||
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
|
||||
|
||||
/// Computed property for backwards compatibility - merges granular state into NodeProfile
|
||||
var nodeProfiles: [String: NodeProfile] {
|
||||
var profiles: [String: NodeProfile] = [:]
|
||||
let allNodeIds = Set(nodeIdentities.keys)
|
||||
.union(nodeMemory.keys)
|
||||
.union(nodeSystem.keys)
|
||||
for nodeId in allNodeIds {
|
||||
let identity = nodeIdentities[nodeId]
|
||||
let memory = nodeMemory[nodeId]
|
||||
let system = nodeSystem[nodeId]
|
||||
profiles[nodeId] = NodeProfile(
|
||||
modelId: identity?.modelId,
|
||||
chipId: identity?.chipId,
|
||||
friendlyName: identity?.friendlyName,
|
||||
memory: memory,
|
||||
system: system
|
||||
)
|
||||
}
|
||||
return profiles
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
self.thunderboltBridgeCycles =
|
||||
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
|
||||
|
||||
// Granular node state
|
||||
self.nodeIdentities =
|
||||
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
|
||||
?? [:]
|
||||
self.nodeMemory =
|
||||
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
|
||||
self.nodeSystem =
|
||||
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
|
||||
self.nodeThunderboltBridge =
|
||||
try container.decodeIfPresent(
|
||||
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
|
||||
) ?? [:]
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case instances
|
||||
case runners
|
||||
case nodeProfiles
|
||||
case topology
|
||||
case tasks
|
||||
case downloads
|
||||
case thunderboltBridgeCycles
|
||||
case nodeIdentities
|
||||
case nodeMemory
|
||||
case nodeSystem
|
||||
case nodeThunderboltBridge
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
|
||||
let system: SystemInfo?
|
||||
}
|
||||
|
||||
struct NodeIdentity: Decodable {
|
||||
let modelId: String?
|
||||
let chipId: String?
|
||||
let friendlyName: String?
|
||||
}
|
||||
|
||||
struct ThunderboltBridgeStatus: Decodable {
|
||||
let enabled: Bool
|
||||
let exists: Bool
|
||||
let serviceName: String?
|
||||
}
|
||||
|
||||
struct MemoryInfo: Decodable {
|
||||
let ramTotal: MemoryValue?
|
||||
let ramAvailable: MemoryValue?
|
||||
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
|
||||
}
|
||||
|
||||
struct Topology: Decodable {
|
||||
let nodes: [TopologyNode]
|
||||
let connections: [TopologyConnection]?
|
||||
/// Node IDs in the topology
|
||||
let nodes: [String]
|
||||
/// Flattened list of connections (source -> sink pairs)
|
||||
let connections: [TopologyConnection]
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
|
||||
|
||||
// Connections come as nested map: { source: { sink: [edges] } }
|
||||
// We flatten to array of (source, sink) pairs
|
||||
var flatConnections: [TopologyConnection] = []
|
||||
if let nested = try container.decodeIfPresent(
|
||||
[String: [String: [AnyCodable]]].self, forKey: .connections
|
||||
) {
|
||||
for (source, sinks) in nested {
|
||||
for sink in sinks.keys {
|
||||
flatConnections.append(
|
||||
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
|
||||
}
|
||||
}
|
||||
}
|
||||
self.connections = flatConnections
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case nodes
|
||||
case connections
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyNode: Decodable {
|
||||
let nodeId: String
|
||||
let nodeProfile: NodeProfile
|
||||
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
|
||||
private struct AnyCodable: Decodable {
|
||||
init(from decoder: Decoder) throws {
|
||||
// Just consume the value without storing it
|
||||
_ = try? decoder.singleValueContainer().decode(Bool.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Int.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Double.self)
|
||||
_ = try? decoder.singleValueContainer().decode(String.self)
|
||||
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
|
||||
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyConnection: Decodable {
|
||||
struct TopologyConnection {
|
||||
let localNodeId: String
|
||||
let sendBackNodeId: String
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
private var periodicTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
/// Checks if local network access is working (one-time check).
|
||||
func check() {
|
||||
performCheck()
|
||||
}
|
||||
|
||||
/// Starts periodic checking of local network access.
|
||||
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
|
||||
func startPeriodicChecking(interval: TimeInterval = 10) {
|
||||
stopPeriodicChecking()
|
||||
// Do an immediate check first
|
||||
performCheck()
|
||||
// Then schedule periodic checks
|
||||
periodicTask = Task { [weak self] in
|
||||
while !Task.isCancelled {
|
||||
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
|
||||
guard !Task.isCancelled else { break }
|
||||
self?.performCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops periodic checking.
|
||||
func stopPeriodicChecking() {
|
||||
periodicTask?.cancel()
|
||||
periodicTask = nil
|
||||
}
|
||||
|
||||
private func performCheck() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
// Only show "checking" status on first check to avoid UI flicker
|
||||
if status == .unknown {
|
||||
status = .checking
|
||||
}
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
// Only log on state changes or first check to reduce noise
|
||||
if isFirstCheck || result != self.status {
|
||||
Self.logger.info("Local network check: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
stopPeriodicChecking()
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
|
||||
@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge.sh"
|
||||
// Legacy script path from older versions
|
||||
private static let legacyScriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
}
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
return scriptExists || legacyScriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
# Remove the script (current and legacy paths) and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rm -f "$LEGACY_SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
@@ -107,90 +72,6 @@ enum NetworkSetupHelper {
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
|
||||
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let interval = plist["StartInterval"] as? Int,
|
||||
interval == requiredStartInterval
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private static func installLaunchDaemon() async throws {
|
||||
let installerScript = makeInstallerScript()
|
||||
try runShellAsAdmin(installerScript)
|
||||
}
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript =
|
||||
script
|
||||
|
||||
252
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
252
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
@@ -0,0 +1,252 @@
|
||||
import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
import Security
|
||||
import SystemConfiguration
|
||||
import os.log
|
||||
|
||||
@MainActor
|
||||
final class ThunderboltBridgeService: ObservableObject {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
|
||||
|
||||
@Published private(set) var detectedCycle: [String]?
|
||||
@Published private(set) var hasPromptedForCurrentCycle = false
|
||||
@Published private(set) var lastError: String?
|
||||
|
||||
private weak var clusterStateService: ClusterStateService?
|
||||
private var cancellables = Set<AnyCancellable>()
|
||||
private var previousCycleSignature: String?
|
||||
|
||||
init(clusterStateService: ClusterStateService) {
|
||||
self.clusterStateService = clusterStateService
|
||||
setupObserver()
|
||||
}
|
||||
|
||||
private func setupObserver() {
|
||||
guard let service = clusterStateService else { return }
|
||||
|
||||
service.$latestSnapshot
|
||||
.compactMap { $0 }
|
||||
.sink { [weak self] snapshot in
|
||||
self?.checkForCycles(snapshot: snapshot)
|
||||
}
|
||||
.store(in: &cancellables)
|
||||
}
|
||||
|
||||
private func checkForCycles(snapshot: ClusterState) {
|
||||
let cycles = snapshot.thunderboltBridgeCycles
|
||||
|
||||
// Only consider cycles with more than 2 nodes
|
||||
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
|
||||
// No problematic cycles detected, reset state
|
||||
if detectedCycle != nil {
|
||||
detectedCycle = nil
|
||||
previousCycleSignature = nil
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Create a signature for this cycle to detect if it changed
|
||||
let cycleSignature = firstCycle.sorted().joined(separator: ",")
|
||||
|
||||
// If this is a new/different cycle, reset the prompt state
|
||||
if cycleSignature != previousCycleSignature {
|
||||
previousCycleSignature = cycleSignature
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
|
||||
detectedCycle = firstCycle
|
||||
|
||||
// Only prompt once per cycle
|
||||
if !hasPromptedForCurrentCycle {
|
||||
showDisableBridgePrompt(nodeIds: firstCycle)
|
||||
}
|
||||
}
|
||||
|
||||
private func showDisableBridgePrompt(nodeIds: [String]) {
|
||||
hasPromptedForCurrentCycle = true
|
||||
|
||||
// Get friendly names for the nodes if available
|
||||
let nodeNames = nodeIds.map { nodeId -> String in
|
||||
if let snapshot = clusterStateService?.latestSnapshot,
|
||||
let profile = snapshot.nodeProfiles[nodeId],
|
||||
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
|
||||
{
|
||||
return friendlyName
|
||||
}
|
||||
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
|
||||
}
|
||||
let machineNames = nodeNames.joined(separator: ", ")
|
||||
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Thunderbolt Bridge Loop Detected"
|
||||
alert.informativeText = """
|
||||
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
|
||||
|
||||
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Disable Bridge")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
|
||||
let response = alert.runModal()
|
||||
|
||||
if response == .alertFirstButtonReturn {
|
||||
Task {
|
||||
await disableThunderboltBridge()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func disableThunderboltBridge() async {
|
||||
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
|
||||
lastError = nil
|
||||
|
||||
do {
|
||||
try await disableThunderboltBridgeWithSCPreferences()
|
||||
Self.logger.info("Successfully disabled Thunderbolt Bridge")
|
||||
} catch {
|
||||
Self.logger.error(
|
||||
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
lastError = error.localizedDescription
|
||||
showErrorAlert(message: error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
private func disableThunderboltBridgeWithSCPreferences() async throws {
|
||||
// 1. Create authorization reference
|
||||
var authRef: AuthorizationRef?
|
||||
var status = AuthorizationCreate(nil, nil, [], &authRef)
|
||||
guard status == errAuthorizationSuccess, let authRef = authRef else {
|
||||
throw ThunderboltBridgeError.authorizationFailed
|
||||
}
|
||||
|
||||
defer { AuthorizationFree(authRef, [.destroyRights]) }
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
}
|
||||
throw ThunderboltBridgeError.authorizationDenied
|
||||
}
|
||||
|
||||
// 3. Create SCPreferences with authorization
|
||||
guard
|
||||
let prefs = SCPreferencesCreateWithAuthorization(
|
||||
kCFAllocatorDefault,
|
||||
"EXO" as CFString,
|
||||
nil,
|
||||
authRef
|
||||
)
|
||||
else {
|
||||
throw ThunderboltBridgeError.preferencesCreationFailed
|
||||
}
|
||||
|
||||
// 4. Lock, modify, commit
|
||||
guard SCPreferencesLock(prefs, true) else {
|
||||
throw ThunderboltBridgeError.lockFailed
|
||||
}
|
||||
|
||||
defer {
|
||||
SCPreferencesUnlock(prefs)
|
||||
}
|
||||
|
||||
// 5. Find and disable Thunderbolt Bridge service
|
||||
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
|
||||
throw ThunderboltBridgeError.servicesNotFound
|
||||
}
|
||||
|
||||
var found = false
|
||||
for service in allServices {
|
||||
if let name = SCNetworkServiceGetName(service) as String?,
|
||||
name == "Thunderbolt Bridge"
|
||||
{
|
||||
guard SCNetworkServiceSetEnabled(service, false) else {
|
||||
throw ThunderboltBridgeError.disableFailed
|
||||
}
|
||||
found = true
|
||||
Self.logger.info("Found and disabled Thunderbolt Bridge service")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
throw ThunderboltBridgeError.serviceNotFound
|
||||
}
|
||||
|
||||
// 6. Commit and apply
|
||||
guard SCPreferencesCommitChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.commitFailed
|
||||
}
|
||||
|
||||
guard SCPreferencesApplyChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.applyFailed
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Failed to Disable Thunderbolt Bridge"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
}
|
||||
|
||||
enum ThunderboltBridgeError: LocalizedError {
|
||||
case authorizationFailed
|
||||
case authorizationCanceled
|
||||
case authorizationDenied
|
||||
case preferencesCreationFailed
|
||||
case lockFailed
|
||||
case servicesNotFound
|
||||
case serviceNotFound
|
||||
case disableFailed
|
||||
case commitFailed
|
||||
case applyFailed
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .authorizationFailed:
|
||||
return "Failed to create authorization"
|
||||
case .authorizationCanceled:
|
||||
return "Authorization was canceled by user"
|
||||
case .authorizationDenied:
|
||||
return "Authorization was denied"
|
||||
case .preferencesCreationFailed:
|
||||
return "Failed to access network preferences"
|
||||
case .lockFailed:
|
||||
return "Failed to lock network preferences for modification"
|
||||
case .servicesNotFound:
|
||||
return "Could not retrieve network services"
|
||||
case .serviceNotFound:
|
||||
return "Thunderbolt Bridge service not found"
|
||||
case .disableFailed:
|
||||
return "Failed to disable Thunderbolt Bridge service"
|
||||
case .commitFailed:
|
||||
return "Failed to save network configuration changes"
|
||||
case .applyFailed:
|
||||
return "Failed to apply network configuration changes"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ struct TopologyViewModel {
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let topologyNodeIds = Set(topology?.nodes ?? [])
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
@@ -95,8 +95,8 @@ extension ClusterState {
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
var orderedNodes: [NodeViewModel] = []
|
||||
if let topologyNodes = topology?.nodes {
|
||||
for topoNode in topologyNodes {
|
||||
if let viewModel = nodesById[topoNode.nodeId] {
|
||||
for nodeId in topologyNodes {
|
||||
if let viewModel = nodesById[nodeId] {
|
||||
orderedNodes.append(viewModel)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ extension ClusterState {
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
topology?.connections.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { onMount, onDestroy } from 'svelte';
|
||||
import * as d3 from 'd3';
|
||||
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
|
||||
import { topologyData, isTopologyMinimized, debugMode, nodeThunderboltBridge, type NodeInfo } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -16,6 +16,7 @@ import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$li
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -932,6 +933,25 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('fill', 'rgba(179,179,179,0.7)')
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// Debug mode: Show TB bridge status
|
||||
if (debugEnabled) {
|
||||
const tbStatus = tbBridgeData[nodeInfo.id];
|
||||
if (tbStatus) {
|
||||
const tbY = nodeInfo.y + iconBaseHeight / 2 + (showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const tbFontSize = showFullLabels ? 9 : 7;
|
||||
const tbColor = tbStatus.enabled ? 'rgba(234,179,8,0.9)' : 'rgba(100,100,100,0.7)';
|
||||
const tbText = tbStatus.enabled ? 'TB:ON' : 'TB:OFF';
|
||||
nodeG.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', tbY)
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('fill', tbColor)
|
||||
.attr('font-size', tbFontSize)
|
||||
.attr('font-family', 'SF Mono, Monaco, monospace')
|
||||
.text(tbText);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@@ -173,6 +173,12 @@ export interface PlacementPreviewResponse {
|
||||
previews: PlacementPreview[];
|
||||
}
|
||||
|
||||
interface RawThunderboltBridgeStatus {
|
||||
enabled: boolean;
|
||||
exists: boolean;
|
||||
serviceName?: string | null;
|
||||
}
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -190,6 +196,10 @@ interface RawStateResponse {
|
||||
nodeMemory?: Record<string, RawMemoryUsage>;
|
||||
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
|
||||
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
|
||||
// Thunderbolt bridge status per node
|
||||
nodeThunderboltBridge?: Record<string, RawThunderboltBridgeStatus>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -387,6 +397,13 @@ class AppStore {
|
||||
selectedPreviewModelId = $state<string | null>(null);
|
||||
isLoadingPreviews = $state(false);
|
||||
lastUpdate = $state<number | null>(null);
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderboltBridge = $state<
|
||||
Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>
|
||||
>({});
|
||||
|
||||
// UI state
|
||||
isTopologyMinimized = $state(false);
|
||||
@@ -915,6 +932,16 @@ class AppStore {
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
if (data.thunderboltBridgeCycles) {
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles;
|
||||
} else {
|
||||
this.thunderboltBridgeCycles = [];
|
||||
}
|
||||
if (data.nodeThunderboltBridge) {
|
||||
this.nodeThunderboltBridge = data.nodeThunderboltBridge;
|
||||
} else {
|
||||
this.nodeThunderboltBridge = {};
|
||||
}
|
||||
this.lastUpdate = Date.now();
|
||||
} catch (error) {
|
||||
console.error("Error fetching state:", error);
|
||||
@@ -1620,6 +1647,8 @@ export const placementPreviews = () => appStore.placementPreviews;
|
||||
export const selectedPreviewModelId = () => appStore.selectedPreviewModelId;
|
||||
export const isLoadingPreviews = () => appStore.isLoadingPreviews;
|
||||
export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
|
||||
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
<script lang="ts">
|
||||
import { TopologyGraph, ChatForm, ChatMessages, ChatSidebar, ModelCard } from '$lib/components';
|
||||
import {
|
||||
hasStartedChat,
|
||||
isTopologyMinimized,
|
||||
topologyData,
|
||||
import {
|
||||
hasStartedChat,
|
||||
isTopologyMinimized,
|
||||
topologyData,
|
||||
lastUpdate,
|
||||
clearChat,
|
||||
instances,
|
||||
@@ -22,6 +22,8 @@
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview
|
||||
} from '$lib/stores/app.svelte';
|
||||
@@ -43,6 +45,25 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
}
|
||||
|
||||
// Helper to get TB bridge service name from any node in the cycle
|
||||
function getTbBridgeServiceName(cycleNodes: string[]): string {
|
||||
for (const nodeId of cycleNodes) {
|
||||
const status = tbBridgeData[nodeId];
|
||||
if (status?.serviceName) {
|
||||
return status.serviceName;
|
||||
}
|
||||
}
|
||||
return 'Thunderbolt Bridge'; // fallback
|
||||
}
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -128,7 +149,10 @@ let chatScrollRef: HTMLDivElement | null = $state(null);
|
||||
|
||||
// Preview card hover state for highlighting nodes in topology
|
||||
let hoveredPreviewNodes = $state<Set<string>>(new Set());
|
||||
|
||||
|
||||
// Thunderbolt bridge cycle hover state for highlighting nodes in topology
|
||||
let hoveredBridgeCycleNodes = $state<Set<string>>(new Set());
|
||||
|
||||
// Helper to unwrap tagged instance for hover highlighting
|
||||
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== 'object') return new Set();
|
||||
@@ -151,7 +175,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
instanceDownloadExpandedNodes = next;
|
||||
}
|
||||
|
||||
// Compute highlighted nodes from hovered instance or hovered preview
|
||||
// Compute highlighted nodes from hovered instance, preview, or bridge cycle
|
||||
const highlightedNodes = $derived(() => {
|
||||
// First check instance hover
|
||||
if (hoveredInstanceId) {
|
||||
@@ -162,6 +186,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
if (hoveredPreviewNodes.size > 0) {
|
||||
return hoveredPreviewNodes;
|
||||
}
|
||||
// Then check bridge cycle hover
|
||||
if (hoveredBridgeCycleNodes.size > 0) {
|
||||
return hoveredBridgeCycleNodes;
|
||||
}
|
||||
return new Set<string>();
|
||||
});
|
||||
|
||||
@@ -1233,6 +1261,58 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
|
||||
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
|
||||
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div
|
||||
class="absolute top-4 left-4 group"
|
||||
role="alert"
|
||||
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
|
||||
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
|
||||
>
|
||||
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
|
||||
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
|
||||
</div>
|
||||
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
|
||||
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
|
||||
<div class="text-xs text-yellow-400 font-mono mb-3">
|
||||
{cycle.length} machines detected in a Thunderbolt Bridge loop:
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1.5 mb-4">
|
||||
{#each cycle as nodeId}
|
||||
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
|
||||
{getNodeName(nodeId)}
|
||||
</span>
|
||||
{/each}
|
||||
</div>
|
||||
<div class="text-xs text-white/60 mb-2">
|
||||
Run on each machine to disable bridge:
|
||||
</div>
|
||||
<div class="relative group/cmd">
|
||||
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
|
||||
<button
|
||||
type="button"
|
||||
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
|
||||
onclick={() => navigator.clipboard.writeText(disableCmd)}
|
||||
title="Copy command"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
|
||||
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
@@ -1258,11 +1338,65 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
<!-- Topology Container - Takes most of the space -->
|
||||
<div class="flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden">
|
||||
|
||||
|
||||
<!-- The main topology graph - full container -->
|
||||
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div
|
||||
class="absolute top-4 left-4 group"
|
||||
role="alert"
|
||||
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
|
||||
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
|
||||
>
|
||||
<!-- Warning Badge -->
|
||||
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
|
||||
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
|
||||
</div>
|
||||
|
||||
<!-- Hover Tooltip -->
|
||||
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
|
||||
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
|
||||
<div class="text-xs text-yellow-400 font-mono mb-3">
|
||||
{cycle.length} machines detected in a Thunderbolt Bridge loop:
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1.5 mb-4">
|
||||
{#each cycle as nodeId}
|
||||
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
|
||||
{getNodeName(nodeId)}
|
||||
</span>
|
||||
{/each}
|
||||
</div>
|
||||
<div class="text-xs text-white/60 mb-2">
|
||||
Run on each machine to disable bridge:
|
||||
</div>
|
||||
<div class="relative group/cmd">
|
||||
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
|
||||
<button
|
||||
type="button"
|
||||
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
|
||||
onclick={() => navigator.clipboard.writeText(disableCmd)}
|
||||
title="Copy command"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
|
||||
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Chat Input - Below topology -->
|
||||
<div class="px-4 pt-6 pb-8">
|
||||
<div class="max-w-3xl mx-auto">
|
||||
@@ -1779,8 +1913,21 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
</div>
|
||||
|
||||
<div class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden">
|
||||
|
||||
|
||||
<TopologyGraph highlightedNodes={highlightedNodes()} />
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected - click to view details"
|
||||
>
|
||||
<svg class="w-3 h-3 text-yellow-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-400">TB CYCLE</span>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ dependencies = [
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -19,8 +19,11 @@ from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_meta import get_model_card
|
||||
from exo.shared.models.model_cards import (
|
||||
MODEL_CARDS,
|
||||
ModelCard,
|
||||
ModelId,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
@@ -86,12 +89,12 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: str) -> ModelCard:
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
else:
|
||||
return await get_model_card(model_id)
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
@@ -236,7 +239,7 @@ class API:
|
||||
|
||||
async def get_placement(
|
||||
self,
|
||||
model_id: str,
|
||||
model_id: ModelId,
|
||||
sharding: Sharding = Sharding.Pipeline,
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
@@ -551,7 +554,7 @@ class API:
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_card = await resolve_model_card(payload.model)
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -578,7 +581,7 @@ class API:
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_card = await resolve_model_card(payload.model)
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
|
||||
@@ -29,6 +29,7 @@ from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -45,6 +46,7 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
NodeConfig,
|
||||
NodeNetworkInterfaces,
|
||||
StaticNodeInformation,
|
||||
ThunderboltBridgeInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,6 +226,15 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
for key, value in state.node_thunderbolt.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_thunderbolt_bridge = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
# Recompute cycles after removing the node
|
||||
thunderbolt_bridge_cycles = topology.get_thunderbolt_bridge_cycles(
|
||||
node_thunderbolt_bridge
|
||||
)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
@@ -234,6 +245,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -311,6 +324,16 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
]
|
||||
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
|
||||
case ThunderboltBridgeInfo():
|
||||
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
|
||||
**state.node_thunderbolt_bridge,
|
||||
event.node_id: info.status,
|
||||
}
|
||||
update["node_thunderbolt_bridge"] = new_tb_bridge
|
||||
# Recompute cycles with updated bridge status
|
||||
update["thunderbolt_bridge_cycles"] = (
|
||||
topology.get_thunderbolt_bridge_cycles(new_tb_bridge)
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from pydantic import PositiveInt
|
||||
from typing import Annotated
|
||||
|
||||
from exo.shared.types.common import Id
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import tomlkit
|
||||
from anyio import Path, open_file
|
||||
from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class ModelId(Id):
|
||||
def normalize(self) -> str:
|
||||
return self.replace("/", "--")
|
||||
|
||||
def short(self) -> str:
|
||||
return self.split("/")[-1]
|
||||
_card_cache: dict[str, "ModelCard"] = {}
|
||||
|
||||
|
||||
class ModelCard(CamelCaseModel):
|
||||
@@ -20,6 +22,43 @@ class ModelCard(CamelCaseModel):
|
||||
hidden_size: PositiveInt
|
||||
supports_tensor: bool
|
||||
|
||||
async def save(self, path: Path) -> None:
|
||||
async with await open_file(path, "w") as f:
|
||||
py = self.model_dump()
|
||||
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
|
||||
await f.write(data)
|
||||
|
||||
@staticmethod
|
||||
async def load_from_path(path: Path) -> "ModelCard":
|
||||
async with await open_file(path, "r") as f:
|
||||
py = tomlkit.loads(await f.read())
|
||||
return ModelCard.model_validate(py)
|
||||
|
||||
@staticmethod
|
||||
async def load(model_id: ModelId) -> "ModelCard":
|
||||
if model_id in MODEL_CARDS:
|
||||
return MODEL_CARDS[model_id]
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
@staticmethod
|
||||
async def from_hf(model_id: ModelId) -> "ModelCard":
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
if (mc := _card_cache.get(model_id)) is not None:
|
||||
return mc
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
|
||||
mc = ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
)
|
||||
_card_cache[model_id] = mc
|
||||
return mc
|
||||
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
@@ -308,3 +347,99 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
}
|
||||
|
||||
from exo.worker.download.download_utils import ( # noqa: E402
|
||||
ModelSafetensorsIndex,
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
architectures: list[str] | None = None
|
||||
|
||||
@property
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
["Qwen3MoeForCausalLM"],
|
||||
["MiniMaxM2ForCausalLM"],
|
||||
["LlamaForCausalLM"],
|
||||
["GptOssForCausalLM"],
|
||||
]
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
config_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(config_path, "r") as f:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors.index.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(index_path, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
metadata = index_data.metadata
|
||||
if metadata is not None:
|
||||
return Memory.from_bytes(metadata.total_size)
|
||||
|
||||
info = model_info(model_id)
|
||||
if info.safetensors is None:
|
||||
raise ValueError(f"No safetensors info found for {model_id}")
|
||||
return Memory.from_bytes(info.safetensors.total)
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
from typing import Annotated
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.worker.download.download_utils import (
|
||||
ModelSafetensorsIndex,
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
|
||||
|
||||
async def get_config_data(model_id: str) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
config_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(config_path, "r") as f:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def get_safetensors_size(model_id: str) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors.index.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(index_path, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
metadata = index_data.metadata
|
||||
if metadata is not None:
|
||||
return Memory.from_bytes(metadata.total_size)
|
||||
|
||||
info = model_info(model_id)
|
||||
if info.safetensors is None:
|
||||
raise ValueError(f"No safetensors info found for {model_id}")
|
||||
return Memory.from_bytes(info.safetensors.total)
|
||||
|
||||
|
||||
_model_card_cache: dict[str, ModelCard] = {}
|
||||
|
||||
|
||||
async def get_model_card(model_id: str) -> ModelCard:
|
||||
if model_id in _model_card_cache:
|
||||
return _model_card_cache[model_id]
|
||||
model_card = await _get_model_card(model_id)
|
||||
_model_card_cache[model_id] = model_card
|
||||
return model_card
|
||||
|
||||
|
||||
async def _get_model_card(model_id: str) -> ModelCard:
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
model_card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
|
||||
None,
|
||||
)
|
||||
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.supports_tensor if model_card is not None else False,
|
||||
)
|
||||
@@ -7,6 +7,7 @@ import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import ThunderboltBridgeStatus
|
||||
from exo.shared.types.topology import (
|
||||
Connection,
|
||||
Cycle,
|
||||
@@ -234,3 +235,31 @@ class Topology:
|
||||
if not has_tb:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_thunderbolt_bridge_cycles(
|
||||
self,
|
||||
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
|
||||
) -> list[list[NodeId]]:
|
||||
"""
|
||||
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
|
||||
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
|
||||
2 or fewer nodes don't cause the broadcast storm problem.
|
||||
"""
|
||||
tb_cycles = self.get_cycles_tb()
|
||||
result: list[list[NodeId]] = []
|
||||
|
||||
for cycle in tb_cycles:
|
||||
node_ids = list(cycle.node_ids)
|
||||
# Only consider cycles with more than 2 nodes
|
||||
if len(node_ids) <= 2:
|
||||
continue
|
||||
# Check if all nodes in the cycle have TB bridge enabled
|
||||
all_enabled = all(
|
||||
node_id in node_tb_bridge_status
|
||||
and node_tb_bridge_status[node_id].enabled
|
||||
for node_id in node_ids
|
||||
)
|
||||
if all_enabled:
|
||||
result.append(node_ids)
|
||||
|
||||
return result
|
||||
|
||||
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
|
||||
|
||||
class PlaceInstanceParams(BaseModel):
|
||||
model_id: str
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
|
||||
@@ -25,6 +25,14 @@ class NodeId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class ModelId(Id):
|
||||
def normalize(self) -> str:
|
||||
return self.replace("/", "--")
|
||||
|
||||
def short(self) -> str:
|
||||
return self.split("/")[-1]
|
||||
|
||||
|
||||
class SessionId(CamelCaseModel):
|
||||
master_node_id: NodeId
|
||||
election_clock: int
|
||||
|
||||
@@ -71,3 +71,11 @@ class NodeThunderboltInfo(CamelCaseModel):
|
||||
"""Thunderbolt interface identifiers for a node."""
|
||||
|
||||
interfaces: Sequence[ThunderboltIdentifier] = []
|
||||
|
||||
|
||||
class ThunderboltBridgeStatus(CamelCaseModel):
|
||||
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
|
||||
|
||||
enabled: bool
|
||||
exists: bool
|
||||
service_name: str | None = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
SystemPerformanceProfile,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
|
||||
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
|
||||
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
|
||||
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
|
||||
|
||||
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
|
||||
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
|
||||
|
||||
@field_serializer("topology", mode="plain")
|
||||
def _encode_topology(self, value: Topology) -> TopologySnapshot:
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from datetime import timedelta
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
|
||||
DownloadProgress = (
|
||||
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
|
||||
)
|
||||
|
||||
|
||||
class ModelSafetensorsIndexMetadata(BaseModel):
    """The "metadata" section of a model.safetensors.index.json file."""

    # Total weight size in bytes; validated as strictly positive.
    total_size: PositiveInt
|
||||
|
||||
|
||||
class ModelSafetensorsIndex(BaseModel):
    """Parsed model.safetensors.index.json for a safetensors-sharded model."""

    # Optional metadata block; None when the index file omits it.
    metadata: ModelSafetensorsIndexMetadata | None
    # Maps each tensor name to the .safetensors file containing it.
    weight_map: dict[str, str]
|
||||
|
||||
|
||||
class FileListEntry(BaseModel):
    """One entry returned by the Hugging Face repo "tree" listing API."""

    # Whether this entry is a regular file or a directory.
    type: Literal["file", "directory"]
    # Path of the entry relative to the repo root.
    path: str
    # Size in bytes; None when absent — presumably for directories, TODO confirm.
    size: int | None = None
|
||||
|
||||
|
||||
class RepoFileDownloadProgress(BaseModel):
    """Immutable snapshot of download progress for a single repo file."""

    repo_id: str
    repo_revision: str
    # Path of the file within the repo.
    file_path: str
    # Bytes present on disk so far (including previous sessions).
    downloaded: Memory
    # Bytes fetched during the current process lifetime only.
    downloaded_this_session: Memory
    total: Memory
    # Transfer rate — units not established here; presumably bytes/sec.
    speed: float
    eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    # Timestamp when this file's download began — presumably time.time(); TODO confirm.
    start_time: float

    # Frozen: instances are immutable progress snapshots.
    model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
    """Immutable aggregate download progress across all files of a repo shard."""

    repo_id: str
    repo_revision: str
    # Shard whose file set this progress describes.
    shard: ShardMetadata
    # Count of files whose download is complete.
    completed_files: int
    total_files: int
    downloaded_bytes: Memory
    # Bytes fetched during the current process lifetime only.
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    # Aggregate transfer rate — presumably bytes/sec; TODO confirm units.
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    # Per-file progress keyed by the file's path within the repo.
    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)

    # Frozen: instances are immutable progress snapshots.
    model_config = ConfigDict(frozen=True)
|
||||
|
||||
@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.thunderbolt import (
|
||||
ThunderboltConnection,
|
||||
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
|
||||
IS_DARWIN = sys.platform == "darwin"
|
||||
|
||||
|
||||
async def _get_thunderbolt_devices() -> set[str] | None:
    """Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.

    Returns None if the networksetup command fails.
    """
    proc = await anyio.run_process(
        ["networksetup", "-listallhardwareports"],
        check=False,
    )
    if proc.returncode != 0:
        logger.warning(
            f"networksetup -listallhardwareports failed with code "
            f"{proc.returncode}: {proc.stderr.decode()}"
        )
        return None

    # Output pairs "Hardware Port: <name>" lines with "Device: <ifname>" lines;
    # collect the device of every port whose name mentions Thunderbolt.
    devices: set[str] = set()
    port_name: str | None = None

    for raw_line in proc.stdout.decode().splitlines():
        text = raw_line.strip()
        if text.startswith("Hardware Port:"):
            port_name = text.split(":", 1)[1].strip()
        elif text.startswith("Device:") and port_name:
            if "thunderbolt" in port_name.lower():
                devices.add(text.split(":", 1)[1].strip())
            port_name = None

    return devices
|
||||
|
||||
|
||||
async def _get_bridge_services() -> dict[str, str] | None:
    """Get mapping of bridge device -> service name from network service order.

    Returns None if the networksetup command fails.
    """
    result = await anyio.run_process(
        ["networksetup", "-listnetworkserviceorder"],
        check=False,
    )
    if result.returncode != 0:
        logger.warning(
            f"networksetup -listnetworkserviceorder failed with code "
            f"{result.returncode}: {result.stderr.decode()}"
        )
        return None

    # Parse service order to find bridge devices and their service names
    # Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
    service_order_output = result.stdout.decode()
    bridge_services: dict[str, str] = {}  # device -> service name
    current_service: str | None = None

    for line in service_order_output.splitlines():
        line = line.strip()
        # Match "(N) Service Name" or "(*) Service Name" (disabled)
        # but NOT "(Hardware Port: ...)" lines
        if (
            line
            and line.startswith("(")
            and ")" in line
            and not line.startswith("(Hardware Port:")
        ):
            paren_end = line.index(")")
            # Skip the ") " separator; guard against a line ending right at ")".
            # NOTE(review): if this guard fails, current_service keeps its
            # previous value, so a malformed header line could associate the
            # following bridge with the prior service — confirm acceptable.
            if paren_end + 2 <= len(line):
                current_service = line[paren_end + 2 :]
        # Match "(Hardware Port: ..., Device: bridgeX)"
        elif current_service and "Device: bridge" in line:
            # Extract device name from "..., Device: bridge0)"
            device_start = line.find("Device: ") + len("Device: ")
            device_end = line.find(")", device_start)
            if device_end > device_start:
                device = line[device_start:device_end]
                bridge_services[device] = current_service

    return bridge_services
|
||||
|
||||
|
||||
async def _get_bridge_members(bridge_device: str) -> set[str]:
    """Get member interfaces of a bridge device via ifconfig."""
    proc = await anyio.run_process(
        ["ifconfig", bridge_device],
        check=False,
    )
    if proc.returncode != 0:
        logger.debug(f"ifconfig {bridge_device} failed with code {proc.returncode}")
        return set()

    # ifconfig lists bridge members as "member: enX flags=..."; the interface
    # name is the second whitespace-separated token.
    stripped_lines = (ln.strip() for ln in proc.stdout.decode().splitlines())
    return {
        fields[1]
        for fields in (ln.split() for ln in stripped_lines if ln.startswith("member:"))
        if len(fields) > 1
    }
|
||||
|
||||
|
||||
async def _find_thunderbolt_bridge(
    bridge_services: dict[str, str], thunderbolt_devices: set[str]
) -> str | None:
    """Find the service name of a bridge containing Thunderbolt interfaces.

    Returns the service name if found, None otherwise.
    """
    for bridge_device, service in bridge_services.items():
        current_members = await _get_bridge_members(bridge_device)
        # The bridge is Thunderbolt-backed when it shares any member
        # interface with the detected Thunderbolt devices.
        if not current_members.isdisjoint(thunderbolt_devices):
            return service
    return None
|
||||
|
||||
|
||||
async def _is_service_enabled(service_name: str) -> bool | None:
    """Check if a network service is enabled.

    Returns True if enabled, False if disabled, None on error.
    """
    proc = await anyio.run_process(
        ["networksetup", "-getnetworkserviceenabled", service_name],
        check=False,
    )
    if proc.returncode != 0:
        logger.warning(
            f"networksetup -getnetworkserviceenabled '{service_name}' "
            f"failed with code {proc.returncode}: {proc.stderr.decode()}"
        )
        return None

    # networksetup prints "Enabled" or "Disabled"; normalise before comparing.
    return proc.stdout.decode().strip().lower() == "enabled"
|
||||
|
||||
|
||||
class StaticNodeInformation(TaggedModel):
|
||||
"""Node information that should NEVER change, to be gathered once at startup"""
|
||||
|
||||
@@ -58,6 +195,64 @@ class MacThunderboltConnections(TaggedModel):
|
||||
conns: Sequence[ThunderboltConnection]
|
||||
|
||||
|
||||
class ThunderboltBridgeInfo(TaggedModel):
    # Snapshot of whether this node's Thunderbolt Bridge service exists and is enabled.
    status: ThunderboltBridgeStatus

    @classmethod
    async def gather(cls) -> Self | None:
        """Check if a Thunderbolt Bridge network service is enabled on this node.

        Detection approach:
        1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
        2. Find bridge devices from network service order (not hardware ports, as
           bridges may not appear there)
        3. Check each bridge's members via ifconfig
        4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
        5. Check if that network service is enabled

        Returns None on non-macOS platforms or when an unexpected error occurs.
        """
        if not IS_DARWIN:
            return None

        def _no_bridge_status() -> Self:
            # Result used whenever no Thunderbolt bridge could be identified.
            return cls(
                status=ThunderboltBridgeStatus(
                    enabled=False, exists=False, service_name=None
                )
            )

        try:
            tb_devices = await _get_thunderbolt_devices()
            if tb_devices is None:
                return _no_bridge_status()

            bridge_services = await _get_bridge_services()
            if not bridge_services:
                return _no_bridge_status()

            tb_service_name = await _find_thunderbolt_bridge(bridge_services, tb_devices)
            if not tb_service_name:
                return _no_bridge_status()

            enabled = await _is_service_enabled(tb_service_name)
            if enabled is None:
                # The service exists but its enabled state could not be read;
                # report it as present but not enabled.
                return cls(
                    status=ThunderboltBridgeStatus(
                        enabled=False, exists=True, service_name=tb_service_name
                    )
                )

            return cls(
                status=ThunderboltBridgeStatus(
                    enabled=enabled,
                    exists=True,
                    service_name=tb_service_name,
                )
            )
        except Exception as e:
            # Best-effort gathering: probing failures must never crash the caller.
            logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
            return None
|
||||
|
||||
|
||||
class NodeConfig(TaggedModel):
|
||||
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||
|
||||
@@ -111,6 +306,7 @@ GatheredInfo = (
|
||||
| NodeNetworkInterfaces
|
||||
| MacThunderboltIdentifiers
|
||||
| MacThunderboltConnections
|
||||
| ThunderboltBridgeInfo
|
||||
| NodeConfig
|
||||
| MiscData
|
||||
| StaticNodeInformation
|
||||
@@ -125,6 +321,7 @@ class InfoGatherer:
|
||||
system_profiler_interval: float | None = 5 if IS_DARWIN else None
|
||||
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
async def run(self):
|
||||
@@ -133,6 +330,7 @@ class InfoGatherer:
|
||||
if (macmon_path := shutil.which("macmon")) is not None:
|
||||
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
|
||||
tg.start_soon(self._monitor_thunderbolt_bridge_status)
|
||||
tg.start_soon(self._watch_system_info)
|
||||
tg.start_soon(self._monitor_memory_usage)
|
||||
tg.start_soon(self._monitor_misc)
|
||||
@@ -206,6 +404,17 @@ class InfoGatherer:
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
    """Periodically gather Thunderbolt Bridge status and publish changes."""
    if self.thunderbolt_bridge_poll_interval is None:
        # Polling is disabled on this platform.
        return
    last_reported: ThunderboltBridgeInfo | None = None
    while True:
        snapshot = await ThunderboltBridgeInfo.gather()
        # Publish only successful gatherings that differ from the last report.
        if snapshot is not None and snapshot != last_reported:
            last_reported = snapshot
            await self.info_sender.send(snapshot)
        await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
async def _monitor_macmon(self, macmon_path: str):
|
||||
if self.macmon_interval is None:
|
||||
return
|
||||
|
||||
@@ -17,17 +17,20 @@ import aiohttp
|
||||
import certifi
|
||||
from loguru import logger
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
DirectoryPath,
|
||||
Field,
|
||||
PositiveInt,
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import DownloadProgressData
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadProgressData,
|
||||
FileListEntry,
|
||||
ModelSafetensorsIndex,
|
||||
RepoDownloadProgress,
|
||||
RepoFileDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
|
||||
)
|
||||
|
||||
|
||||
class ModelSafetensorsIndexMetadata(BaseModel):
|
||||
total_size: PositiveInt
|
||||
|
||||
|
||||
class ModelSafetensorsIndex(BaseModel):
|
||||
metadata: ModelSafetensorsIndexMetadata | None
|
||||
weight_map: dict[str, str]
|
||||
|
||||
|
||||
class FileListEntry(BaseModel):
|
||||
type: Literal["file", "directory"]
|
||||
path: str
|
||||
size: int | None = None
|
||||
|
||||
|
||||
class RepoFileDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
file_path: str
|
||||
downloaded: Memory
|
||||
downloaded_this_session: Memory
|
||||
total: Memory
|
||||
speed: float
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
shard: ShardMetadata
|
||||
completed_files: int
|
||||
total_files: int
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
total_bytes: Memory
|
||||
overall_speed: float
|
||||
overall_eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
def trim_etag(etag: str) -> str:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
return etag[1:-1]
|
||||
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
)
|
||||
|
||||
|
||||
def build_model_path(model_id: str) -> DirectoryPath:
|
||||
return EXO_MODELS_DIR / model_id.replace("/", "--")
|
||||
def build_model_path(model_id: ModelId) -> DirectoryPath:
|
||||
return EXO_MODELS_DIR / model_id.normalize()
|
||||
|
||||
|
||||
async def resolve_model_path_for_repo(repo_id: str) -> Path:
|
||||
return (await ensure_models_dir()) / repo_id.replace("/", "--")
|
||||
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
|
||||
return (await ensure_models_dir()) / model_id.normalize()
|
||||
|
||||
|
||||
async def ensure_models_dir() -> Path:
|
||||
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
|
||||
return EXO_MODELS_DIR
|
||||
|
||||
|
||||
async def delete_model(repo_id: str) -> bool:
|
||||
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
|
||||
async def delete_model(model_id: ModelId) -> bool:
|
||||
model_dir = await ensure_models_dir() / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
|
||||
|
||||
|
||||
async def fetch_file_list_with_cache(
|
||||
repo_id: str, revision: str = "main", recursive: bool = False
|
||||
model_id: ModelId, revision: str = "main", recursive: bool = False
|
||||
) -> list[FileListEntry]:
|
||||
target_dir = (
|
||||
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
|
||||
)
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = (
|
||||
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
|
||||
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
|
||||
|
||||
|
||||
async def fetch_file_list_with_retry(
|
||||
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
|
||||
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
|
||||
) -> list[FileListEntry]:
|
||||
n_attempts = 30
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(repo_id, revision, path, recursive)
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
|
||||
raise Exception(
|
||||
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
|
||||
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_file_list(
|
||||
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
|
||||
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
|
||||
) -> list[FileListEntry]:
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_download_headers()
|
||||
@@ -219,7 +173,7 @@ async def _fetch_file_list(
|
||||
files.append(FileListEntry.model_validate(item))
|
||||
elif item.type == "directory" and recursive:
|
||||
subfiles = await _fetch_file_list(
|
||||
repo_id, revision, item.path, recursive
|
||||
model_id, revision, item.path, recursive
|
||||
)
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
|
||||
|
||||
|
||||
async def file_meta(
|
||||
repo_id: str, revision: str, path: str, redirected_location: str | None = None
|
||||
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
|
||||
) -> tuple[int, str]:
|
||||
url = (
|
||||
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
|
||||
if redirected_location is None
|
||||
else f"{get_hf_endpoint()}{redirected_location}"
|
||||
)
|
||||
@@ -298,7 +252,7 @@ async def file_meta(
|
||||
return content_length, etag
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -310,7 +264,7 @@ async def file_meta(
|
||||
|
||||
|
||||
async def download_file_with_retry(
|
||||
repo_id: str,
|
||||
model_id: ModelId,
|
||||
revision: str,
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
@@ -320,23 +274,23 @@ async def download_file_with_retry(
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
repo_id, revision, path, target_dir, on_progress
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
|
||||
raise Exception(
|
||||
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
|
||||
async def _download_file(
|
||||
repo_id: str,
|
||||
model_id: ModelId,
|
||||
revision: str,
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
@@ -345,7 +299,7 @@ async def _download_file(
|
||||
if await aios.path.exists(target_dir / path):
|
||||
return target_dir / path
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(repo_id, revision, path)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
partial_path = target_dir / f"{path}.partial"
|
||||
resume_byte_pos = (
|
||||
@@ -354,7 +308,7 @@ async def _download_file(
|
||||
else None
|
||||
)
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
|
||||
headers = await get_download_headers()
|
||||
if resume_byte_pos:
|
||||
headers["Range"] = f"bytes={resume_byte_pos}-"
|
||||
@@ -394,7 +348,7 @@ async def _download_file(
|
||||
|
||||
def calculate_repo_progress(
|
||||
shard: ShardMetadata,
|
||||
repo_id: str,
|
||||
model_id: ModelId,
|
||||
revision: str,
|
||||
file_progress: dict[str, RepoFileDownloadProgress],
|
||||
all_start_time: float,
|
||||
@@ -423,7 +377,7 @@ def calculate_repo_progress(
|
||||
else "not_started"
|
||||
)
|
||||
return RepoDownloadProgress(
|
||||
repo_id=repo_id,
|
||||
repo_id=model_id,
|
||||
repo_revision=revision,
|
||||
shard=shard,
|
||||
completed_files=len(
|
||||
@@ -442,11 +396,11 @@ def calculate_repo_progress(
|
||||
)
|
||||
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
|
||||
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_file = await download_file_with_retry(
|
||||
repo_id, revision, "model.safetensors.index.json", target_dir
|
||||
model_id, revision, "model.safetensors.index.json", target_dir
|
||||
)
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
@@ -504,7 +458,7 @@ async def download_shard(
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
str(shard.model_card.model_id), revision, recursive=True
|
||||
shard.model_card.model_id, revision, recursive=True
|
||||
)
|
||||
filtered_file_list = list(
|
||||
filter_repo_objects(
|
||||
@@ -538,7 +492,7 @@ async def download_shard(
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=str(shard.model_card.model_id),
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_revision=revision,
|
||||
file_path=file.path,
|
||||
downloaded=Memory.from_bytes(curr_bytes),
|
||||
@@ -555,7 +509,7 @@ async def download_shard(
|
||||
shard,
|
||||
calculate_repo_progress(
|
||||
shard,
|
||||
str(shard.model_card.model_id),
|
||||
shard.model_card.model_id,
|
||||
revision,
|
||||
file_progress,
|
||||
all_start_time,
|
||||
@@ -565,7 +519,7 @@ async def download_shard(
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=str(shard.model_card.model_id),
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_revision=revision,
|
||||
file_path=file.path,
|
||||
downloaded=Memory.from_bytes(downloaded_bytes),
|
||||
@@ -589,7 +543,7 @@ async def download_shard(
|
||||
async def download_with_semaphore(file: FileListEntry) -> None:
|
||||
async with semaphore:
|
||||
await download_file_with_retry(
|
||||
str(shard.model_card.model_id),
|
||||
shard.model_card.model_id,
|
||||
revision,
|
||||
file.path,
|
||||
target_dir,
|
||||
@@ -603,7 +557,7 @@ async def download_shard(
|
||||
*[download_with_semaphore(file) for file in filtered_file_list]
|
||||
)
|
||||
final_repo_progress = calculate_repo_progress(
|
||||
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
|
||||
shard, shard.model_card.model_id, revision, file_progress, all_start_time
|
||||
)
|
||||
await on_progress(shard, final_repo_progress)
|
||||
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
|
||||
|
||||
@@ -3,8 +3,7 @@ from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_card
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
@@ -19,8 +18,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
)
|
||||
|
||||
|
||||
async def build_base_shard(model_id: str) -> ShardMetadata:
|
||||
model_card = await get_model_card(model_id)
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -31,7 +30,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
|
||||
)
|
||||
|
||||
|
||||
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
|
||||
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
|
||||
base_shard = await build_base_shard(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=base_shard.model_card,
|
||||
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self,
|
||||
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
||||
async def _status_for_model(
|
||||
model_id: str,
|
||||
model_id: ModelId,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
"""Helper coroutine that builds the shard for a model and gets its download status."""
|
||||
shard = await build_full_shard(model_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -67,16 +67,27 @@ def eval_with_timeout(
|
||||
completed.set()
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
"""Structural type that any compatible layer must satisfy.
|
||||
|
||||
We require a single positional input of type ``mx.array`` and an
|
||||
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
|
||||
protocol matches the vast majority of `mlx.nn.Module` subclasses.
|
||||
"""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
|
||||
|
||||
|
||||
class CustomMlxLayer(nn.Module):
|
||||
"""Base class for replacing an MLX layer with a custom implementation."""
|
||||
|
||||
def __init__(self, original_layer: nn.Module):
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "_original_layer", original_layer)
|
||||
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
@property
|
||||
def original_layer(self) -> nn.Module:
|
||||
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
|
||||
def original_layer(self) -> _LayerCallable:
|
||||
return cast(_LayerCallable, self["_original_layer"])
|
||||
|
||||
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
|
||||
if not TYPE_CHECKING:
|
||||
@@ -85,57 +96,56 @@ class CustomMlxLayer(nn.Module):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
original_layer = object.__getattribute__(self, "_original_layer")
|
||||
original_layer = cast(_LayerCallable, self["_original_layer"])
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
def patch_pipeline_first_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.group = group
|
||||
|
||||
rank = group.rank()
|
||||
|
||||
class PatchedFirstLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if rank != 0:
|
||||
x = mx.distributed.recv_like(x, (rank - 1), group=group)
|
||||
return orig_call(self, x, *args, **kwargs)
|
||||
|
||||
pipeline_layer.__class__ = PatchedFirstLayer
|
||||
|
||||
return pipeline_layer
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if self.r != 0:
|
||||
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
def patch_pipeline_last_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
orig_call_sig = signature(orig_call)
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
s: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.s: int = s
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
|
||||
rank = group.rank()
|
||||
size = group.size()
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
class PatchedLastLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
|
||||
"cache", None
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output: mx.array = orig_call(self, x, *args, **kwargs)
|
||||
|
||||
if rank != size - 1:
|
||||
output = mx.distributed.send(output, (rank + 1) % size, group=group)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
return output
|
||||
|
||||
pipeline_layer.__class__ = PatchedLastLayer
|
||||
|
||||
return pipeline_layer
|
||||
return output
|
||||
|
||||
|
||||
def _inner_model(model: nn.Module) -> nn.Module:
|
||||
@@ -150,13 +160,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
|
||||
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
|
||||
|
||||
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[nn.Module]
|
||||
layers: list[_LayerCallable]
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
layers = cast(list[nn.Module], inner_model_instance.layers)
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
layers = cast(list[nn.Module], inner_model_instance.h)
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.h)
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
@@ -181,12 +191,15 @@ def pipeline_auto_parallel(
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
layers[0] = patch_pipeline_first_layer(layers[0], group)
|
||||
layers[-1] = patch_pipeline_last_layer(
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
group,
|
||||
device_rank,
|
||||
world_size,
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
@@ -321,7 +334,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
@@ -370,7 +383,6 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -433,7 +445,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
@@ -508,17 +520,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -552,7 +564,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
@@ -586,7 +598,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -599,17 +611,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedQwenMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -661,7 +673,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x) # type: ignore
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
@@ -296,7 +297,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
"""
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
|
||||
@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
return None
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
|
||||
|
||||
@@ -11,14 +11,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
|
||||
|
||||
class MockLayer(nn.Module):
|
||||
@@ -116,11 +117,12 @@ def run_gpt_oss_pipeline_device(
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
@@ -182,11 +184,11 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
|
||||
@@ -10,8 +10,8 @@ import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxLayer,
|
||||
patch_pipeline_first_layer,
|
||||
patch_pipeline_last_layer,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
patch_pipeline_model,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
@@ -50,8 +50,8 @@ def run_pipeline_device(
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
mock = MockLayerInner()
|
||||
first = patch_pipeline_first_layer(mock, group)
|
||||
composed = patch_pipeline_last_layer(first, group)
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
|
||||
inner_model = MockModel([composed])
|
||||
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
group = mx.distributed.init()
|
||||
|
||||
first = patch_pipeline_first_layer(mock, group)
|
||||
composed = patch_pipeline_last_layer(first, group)
|
||||
first = PipelineFirstLayer(mock, r=0, group=group)
|
||||
composed = PipelineLastLayer(first, r=0, s=1, group=group)
|
||||
|
||||
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert composed.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def download_tokenizer_files(model_id: str) -> Path:
|
||||
async def download_tokenizer_files(model_id: ModelId) -> Path:
|
||||
"""Download only the tokenizer-related files for a model."""
|
||||
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
|
||||
target_dir = await ensure_models_dir() / model_id.normalize()
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
|
||||
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: str) -> Path:
|
||||
|
||||
|
||||
# Get a sample of models to test (one per family to keep tests fast)
|
||||
def get_test_models() -> list[tuple[str, ModelCard]]:
|
||||
def get_test_models() -> list[ModelCard]:
|
||||
"""Get a representative sample of models to test."""
|
||||
# Pick one model from each family to test
|
||||
families: dict[str, tuple[str, ModelCard]] = {}
|
||||
for _, card in MODEL_CARDS.items():
|
||||
families: dict[str, ModelCard] = {}
|
||||
for card in MODEL_CARDS.values():
|
||||
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
|
||||
parts = card.model_id.short().split("-")
|
||||
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
|
||||
|
||||
if family not in families:
|
||||
families[family] = (card.model_id.short(), card)
|
||||
families[family] = card
|
||||
|
||||
return list(families.values())
|
||||
|
||||
|
||||
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
|
||||
TEST_MODELS: list[ModelCard] = get_test_models()
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
@@ -101,14 +101,13 @@ def event_loop():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
"model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode and decode text correctly."""
|
||||
model_id = str(model_card.model_id)
|
||||
model_id = model_card.model_id
|
||||
|
||||
# Download tokenizer files
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
@@ -167,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
"model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_has_required_attributes(
|
||||
short_id: str, model_card: ModelCard
|
||||
) -> None:
|
||||
"""Test that tokenizer has required attributes for inference."""
|
||||
model_id = str(model_card.model_id)
|
||||
model_id = model_card.model_id
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
@@ -209,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
"model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
|
||||
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode text containing special tokens.
|
||||
|
||||
This is critical because the actual inference path uses prompts with
|
||||
special tokens from chat templates. If special tokens aren't handled
|
||||
correctly, encoding will fail.
|
||||
"""
|
||||
model_id = str(model_card.model_id)
|
||||
model_id = model_card.model_id
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
@@ -301,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
|
||||
async def test_kimi_tokenizer_specifically():
|
||||
"""Test Kimi tokenizer with its specific patches and quirks."""
|
||||
kimi_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "kimi" in short_id.lower()
|
||||
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
|
||||
]
|
||||
|
||||
if not kimi_models:
|
||||
pytest.skip("No Kimi models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = kimi_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
model_card = kimi_models[0]
|
||||
model_id = model_card.model_id
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
@@ -349,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
|
||||
@pytest.mark.asyncio
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
glm_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "glm" in short_id.lower()
|
||||
glm_model_cards = [
|
||||
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
|
||||
]
|
||||
|
||||
if not glm_models:
|
||||
if not glm_model_cards:
|
||||
pytest.skip("No GLM models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = glm_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
model_card = glm_model_cards[0]
|
||||
model_id = model_card.model_id
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import ModelId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
|
||||
12
uv.lock
generated
12
uv.lock
generated
@@ -248,6 +248,7 @@ dependencies = [
|
||||
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
@@ -281,6 +282,7 @@ requires-dist = [
|
||||
{ name = "pydantic", specifier = ">=2.11.7" },
|
||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||
{ name = "tiktoken", specifier = ">=0.12.0" },
|
||||
{ name = "tomlkit", specifier = ">=0.14.0" },
|
||||
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
|
||||
]
|
||||
|
||||
@@ -315,6 +317,16 @@ dev = [
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomlkit"
|
||||
version = "0.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
|
||||
]
|
||||
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
|
||||
Reference in New Issue
Block a user