Compare commits

...

22 Commits

Author SHA1 Message Date
Alex Cheema
56ae1e22e3 Remove unnecessary step comments
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 00:29:42 +00:00
Alex Cheema
215132c450 Refactor ThunderboltBridgeInfo.gather into smaller helper functions
Extract the monolithic gather() method into focused helpers:
- _get_thunderbolt_devices(): Get TB interface device names from hardware ports
- _get_bridge_services(): Get bridge device -> service name mapping
- _get_bridge_members(): Get member interfaces of a bridge via ifconfig
- _find_thunderbolt_bridge(): Find bridge containing TB interfaces
- _is_service_enabled(): Check if a network service is enabled
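
For orientation, a minimal Python sketch of how these helpers could fit together; the command parsing below is an assumption for illustration, not exo's actual implementation:

```python
# Illustrative sketch only: helper names follow the commit message; parsing details are assumptions.
import re
import subprocess


def _run(*cmd: str) -> str:
    return subprocess.run(cmd, capture_output=True, text=True, check=False).stdout


def _get_thunderbolt_devices() -> set[str]:
    """TB interface device names (e.g. en2) from `networksetup -listallhardwareports`."""
    devices, port = set(), ""
    for line in _run("networksetup", "-listallhardwareports").splitlines():
        if line.startswith("Hardware Port:"):
            port = line.split(":", 1)[1].strip()
        elif line.startswith("Device:") and "Thunderbolt" in port:
            devices.add(line.split(":", 1)[1].strip())
    return devices


def _get_bridge_services() -> dict[str, str]:
    """Bridge device -> network service name from `networksetup -listnetworkserviceorder`."""
    out = _run("networksetup", "-listnetworkserviceorder")
    return {dev: name for name, dev in re.findall(r"Hardware Port: (.+?), Device: (bridge\d+)\)", out)}


def _get_bridge_members(bridge: str) -> set[str]:
    """Member interfaces of a bridge via `ifconfig <bridge>`."""
    return set(re.findall(r"\bmember: (\S+)", _run("ifconfig", bridge)))


def _is_service_enabled(service: str) -> bool:
    return _run("networksetup", "-getnetworkserviceenabled", service).strip() == "Enabled"


def _find_thunderbolt_bridge() -> tuple[str, str] | None:
    """(bridge device, service name) for the first bridge containing TB interfaces."""
    tb = _get_thunderbolt_devices()
    for bridge, service in _get_bridge_services().items():
        if _get_bridge_members(bridge) & tb:
            return bridge, service
    return None


def gather() -> dict[str, object]:
    found = _find_thunderbolt_bridge()
    if found is None:
        return {"exists": False, "enabled": False, "service_name": None}
    bridge, service = found
    return {"exists": True, "enabled": _is_service_enabled(service), "service_name": service}
```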

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 00:28:24 +00:00
Alex Cheema
82040e7e77 Increase notarization timeout to 20m
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 00:26:28 +00:00
Alex Cheema
86fb761c69 Trigger CI 2026-01-20 23:48:14 +00:00
Alex Cheema
08252a81c4 Fix Topology model to match actual JSON structure
The JSON from the backend has:
- nodes: array of node ID strings (not objects)
- connections: nested map { source: { sink: [edges] } } (not flat array)

Updated the Swift Topology model with custom decoder to handle
the actual JSON format and flatten connections for easier use.
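
As a small, made-up illustration of the shape being decoded, flattening amounts to one (source, sink) pair per sink key:

```python
# Made-up example of the backend topology JSON and how it flattens.
topology = {
    "nodes": ["node-a", "node-b", "node-c"],   # node ID strings, not objects
    "connections": {                           # nested map { source: { sink: [edges] } }
        "node-a": {"node-b": [{}]},
        "node-b": {"node-c": [{}]},
    },
}

flat = [(src, sink) for src, sinks in topology["connections"].items() for sink in sinks]
# -> [("node-a", "node-b"), ("node-b", "node-c")]
```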

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 23:19:08 +00:00
Alex Cheema
3b62fa9c1f Fix ClusterState model to match new granular state structure
The backend split NodePerformanceProfile into granular mappings
(nodeIdentities, nodeMemory, nodeSystem, nodeThunderboltBridge).
Update the Swift model to decode these new fields.

Added computed nodeProfiles property for backwards compatibility.
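
Roughly, the payload now carries these per-node maps (field names come from the diff further down; the values here are placeholders):

```python
# Made-up example of the granular per-node state the Swift model now decodes.
cluster_state = {
    "nodeIdentities": {"node-a": {"modelId": "...", "chipId": "...", "friendlyName": "node-a"}},
    "nodeMemory": {"node-a": {"ramTotal": ..., "ramAvailable": ...}},
    "nodeSystem": {"node-a": {}},
    "nodeThunderboltBridge": {"node-a": {"enabled": True, "exists": True, "serviceName": "Thunderbolt Bridge"}},
}
# The computed nodeProfiles property merges nodeIdentities/nodeMemory/nodeSystem back into profiles.
```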

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 22:24:10 +00:00
Alex Cheema
ad7f5b746b Add periodic local network access checking
Re-check every 10 seconds so the warning disappears when the user grants
permission. Error handling is robust, so failures won't crash the app.

Changes:
- Add startPeriodicChecking(interval:) and stopPeriodicChecking()
- Only show "checking" status on first check to avoid UI flicker
- Reduce log noise by only logging state changes
- stop() now also stops periodic checking
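
For readers outside Swift, a rough asyncio analogue of the same pattern (a sketch, not the app's code):

```python
import asyncio


class LocalNetworkChecker:
    """Rough asyncio analogue of the periodic-check pattern described above."""

    def __init__(self) -> None:
        self._periodic: asyncio.Task | None = None
        self.status = "unknown"

    async def _perform_check(self) -> None:
        if self.status == "unknown":      # only show "checking" on the very first run (avoids flicker)
            self.status = "checking"
        # ... actual connectivity probe elided ...

    def start_periodic_checking(self, interval: float = 10.0) -> None:
        """Immediate check, then re-check every `interval` seconds. Call from a running event loop."""
        self.stop_periodic_checking()

        async def loop() -> None:
            await self._perform_check()
            while True:
                await asyncio.sleep(interval)
                await self._perform_check()

        self._periodic = asyncio.create_task(loop())

    def stop_periodic_checking(self) -> None:
        if self._periodic is not None:
            self._periodic.cancel()
            self._periodic = None

    def stop(self) -> None:
        self.stop_periodic_checking()     # stop() now also stops periodic checking
```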

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:57:52 +00:00
Alex Cheema
bcec2bd5b0 Auto-cleanup legacy LaunchDaemon components on app startup
When users upgrade from older versions, the old LaunchDaemon plist and
scripts may interfere with the new dynamic TB bridge detection. This
change automatically removes those components on app startup.

Also handles legacy script path (disable_bridge_enable_dhcp.sh) from
even older versions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:56:24 +00:00
Alex Cheema
6ca54112c0 Fix TB bridge detection: find bridges from network service order
Bridges may not appear in -listallhardwareports but do appear in
-listnetworkserviceorder. Changed approach to:
1. Get Thunderbolt devices from -listallhardwareports
2. Get bridge devices from -listnetworkserviceorder (which also gives
   us the service name directly)
3. Check each bridge's members via ifconfig
4. If bridge has Thunderbolt members, check if service is enabled

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:50 +00:00
Alex Cheema
9267809d2b Fix TB bridge detection to use network service name instead of hardware port name
The Hardware Port name and Network Service name can be different. Use
-listnetworkserviceorder to find the actual service name for the bridge
device, then use that for -getnetworkserviceenabled.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:50 +00:00
Alex Cheema
9cd2dd48aa Add warning logs for Thunderbolt Bridge detection errors
Log warnings when networksetup or ifconfig commands fail during
Thunderbolt Bridge detection to aid debugging.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:50 +00:00
Alex Cheema
cadd66c1a1 Fix formatting for CI
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:50 +00:00
Alex Cheema
fb90e5dc1b Improve TB bridge detection robustness
- Detect TB bridge by checking if bridge contains Thunderbolt interfaces
  rather than assuming bridge0 is always Thunderbolt
- Find Thunderbolt interface devices from hardware ports
- Check bridge members using ifconfig to verify Thunderbolt membership
- Add service_name field to ThunderboltBridgeStatus
- Dashboard shows correct command with actual service name (e.g. "TB Bridge")

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:49 +00:00
Alex Cheema
844b8f6739 Show TB bridge status per node in debug mode
When debug mode is enabled, each node in the topology now shows its
Thunderbolt Bridge status:
- "TB:ON" (yellow) when bridge is enabled
- "TB:OFF" (grey) when bridge is disabled

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:49 +00:00
Alex Cheema
069be5fa23 Fix formatting in apply.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:49 +00:00
Alex Cheema
fcc92e5d5e Add Thunderbolt Bridge cycle warning to dashboard
Shows a yellow warning badge on the topology when a TB bridge cycle
is detected. On hover:
- Highlights the affected nodes in yellow on the topology
- Shows which machines are in the cycle
- Provides a copy-paste command to disable bridge

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:49 +00:00
Alex Cheema
63b74309f5 Replace LaunchDaemon with dynamic Thunderbolt Bridge loop detection
Instead of installing a launchd plist that runs periodically to disable
Thunderbolt Bridge, this change dynamically detects TB bridge loops at
runtime and prompts the user to disable the bridge via the macOS app.

Backend changes:
- Add ThunderboltBridgeStatus type to track bridge enabled state per node
- Add thunderbolt_bridge_cycles to State for detected problematic cycles
- Add get_thunderbolt_bridge_cycles() to Topology class
- Poll bridge status every 10s via networksetup command

Swift app changes:
- New ThunderboltBridgeService monitors for cycles and shows NSAlert
- Uses SCPreferencesCreateWithAuthorization for fine-grained permissions
- Remove automatic LaunchDaemon installation on app startup
- Clean up NetworkSetupHelper by removing daemon-related code
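
The cycle rule itself is small; a hedged, self-contained sketch of it (not exo's actual code, which lives in the Topology diff below):

```python
# Hedged sketch of the cycle rule: a detected Thunderbolt cycle is only flagged when it
# spans 3+ nodes and every node in it reports its bridge service as enabled.
def problematic_cycles(
    tb_cycles: list[list[str]], bridge_enabled: dict[str, bool]
) -> list[list[str]]:
    return [
        cycle
        for cycle in tb_cycles
        if len(cycle) > 2 and all(bridge_enabled.get(node, False) for node in cycle)
    ]


assert problematic_cycles(
    [["a", "b"], ["a", "b", "c"]],
    {"a": True, "b": True, "c": True},
) == [["a", "b", "c"]]
```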

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:16:49 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
Upstream MLX LM added a shard method to GPT OSS, but MLX does not yet have a
matching update.

2026-01-20 18:24:54 +00:00
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Kimi K2 Thinking Pipeline RDMA was broken before.

## Why It Works
No clue tbh

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time on Pipeline RDMA.
Needs exo bench to check more thoroughly.

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the more strongly typed ModelId and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular ModelId.normalize() to
replace `/` with `--`.
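
For illustration, a minimal stand-in (not the real class) showing what these helpers do:

```python
# Minimal stand-in for exo's ModelId (a str-based Id type); the real definition is in the diff below.
class ModelId(str):
    def normalize(self) -> str:
        return self.replace("/", "--")   # safe to use as a directory name

    def short(self) -> str:
        return self.split("/")[-1]


mid = ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit")
assert mid.normalize() == "mlx-community--Llama-3.2-3B-Instruct-4bit"
assert mid.short() == "Llama-3.2-3B-Instruct-4bit"
```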

This PR did introduce some circular imports, so some code has been moved
around to limit them.

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks if `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to HuggingFace repo format and there's no
API pathway to pass local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
true if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed in the cwd, which
doesn't happen in practice. Offline caching is already handled by
`fetch_file_list_with_cache`.

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo was logging a lot of unnecessary messages during periodic download status checks. This resulted in spammy logs that made it hard to see important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when skip_download is true)
- Changed periodic download progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. Guarding the log behind `if not skip_download:` avoids logging on every status check, and changing the periodic progress logs to DEBUG level reduces noise while keeping them available for debugging.
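
A short sketch of the guard described above; names are taken from the description and are assumptions, not copied from the code:

```python
from loguru import logger  # the exo backend logs via loguru


def log_download_start(model_id: str, allow_patterns: list[str], skip_download: bool) -> None:
    # Only announce a download when one will actually happen; status-only checks stay quiet.
    if not skip_download:
        logger.info(f"Downloading {model_id} with allow_patterns={allow_patterns}")
```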

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use -v or -vv flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
32 changed files with 1226 additions and 494 deletions

View File

@@ -315,7 +315,7 @@ jobs:
--apple-id "$APPLE_NOTARIZATION_USERNAME" \
--password "$APPLE_NOTARIZATION_PASSWORD" \
--team-id "$APPLE_NOTARIZATION_TEAM" \
--wait --timeout 15m 2>&1)
--wait --timeout 20m 2>&1)
echo "$SUBMISSION_OUTPUT"
SUBMISSION_ID=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="id:" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')

View File

@@ -14,6 +14,7 @@ struct ContentView: View {
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false

View File

@@ -21,6 +21,7 @@ struct EXOApp: App {
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -41,10 +42,13 @@ struct EXOApp: App {
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -58,6 +62,7 @@ struct EXOApp: App {
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
.environmentObject(thunderboltBridgeService)
} label: {
menuBarIcon
}
@@ -130,6 +135,20 @@ struct EXOApp: App {
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
/// Helper for managing EXO's launch-at-login registration

View File

@@ -5,17 +5,43 @@ import Foundation
struct ClusterState: Decodable {
let instances: [String: ClusterInstance]
let runners: [String: RunnerStatusSummary]
let nodeProfiles: [String: NodeProfile]
let tasks: [String: ClusterTask]
let topology: Topology?
let downloads: [String: [NodeDownloadStatus]]
let thunderboltBridgeCycles: [[String]]
// Granular node state (split from the old nodeProfiles)
let nodeIdentities: [String: NodeIdentity]
let nodeMemory: [String: MemoryInfo]
let nodeSystem: [String: SystemInfo]
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
/// Computed property for backwards compatibility - merges granular state into NodeProfile
var nodeProfiles: [String: NodeProfile] {
var profiles: [String: NodeProfile] = [:]
let allNodeIds = Set(nodeIdentities.keys)
.union(nodeMemory.keys)
.union(nodeSystem.keys)
for nodeId in allNodeIds {
let identity = nodeIdentities[nodeId]
let memory = nodeMemory[nodeId]
let system = nodeSystem[nodeId]
profiles[nodeId] = NodeProfile(
modelId: identity?.modelId,
chipId: identity?.chipId,
friendlyName: identity?.friendlyName,
memory: memory,
system: system
)
}
return profiles
}
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
self.thunderboltBridgeCycles =
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
// Granular node state
self.nodeIdentities =
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
?? [:]
self.nodeMemory =
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
self.nodeSystem =
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
self.nodeThunderboltBridge =
try container.decodeIfPresent(
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
) ?? [:]
}
private enum CodingKeys: String, CodingKey {
case instances
case runners
case nodeProfiles
case topology
case tasks
case downloads
case thunderboltBridgeCycles
case nodeIdentities
case nodeMemory
case nodeSystem
case nodeThunderboltBridge
}
}
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
let system: SystemInfo?
}
struct NodeIdentity: Decodable {
let modelId: String?
let chipId: String?
let friendlyName: String?
}
struct ThunderboltBridgeStatus: Decodable {
let enabled: Bool
let exists: Bool
let serviceName: String?
}
struct MemoryInfo: Decodable {
let ramTotal: MemoryValue?
let ramAvailable: MemoryValue?
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
}
struct Topology: Decodable {
let nodes: [TopologyNode]
let connections: [TopologyConnection]?
/// Node IDs in the topology
let nodes: [String]
/// Flattened list of connections (source -> sink pairs)
let connections: [TopologyConnection]
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
// Connections come as nested map: { source: { sink: [edges] } }
// We flatten to array of (source, sink) pairs
var flatConnections: [TopologyConnection] = []
if let nested = try container.decodeIfPresent(
[String: [String: [AnyCodable]]].self, forKey: .connections
) {
for (source, sinks) in nested {
for sink in sinks.keys {
flatConnections.append(
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
}
}
}
self.connections = flatConnections
}
private enum CodingKeys: String, CodingKey {
case nodes
case connections
}
}
struct TopologyNode: Decodable {
let nodeId: String
let nodeProfile: NodeProfile
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
private struct AnyCodable: Decodable {
init(from decoder: Decoder) throws {
// Just consume the value without storing it
_ = try? decoder.singleValueContainer().decode(Bool.self)
_ = try? decoder.singleValueContainer().decode(Int.self)
_ = try? decoder.singleValueContainer().decode(Double.self)
_ = try? decoder.singleValueContainer().decode(String.self)
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
}
}
struct TopologyConnection: Decodable {
struct TopologyConnection {
let localNodeId: String
let sendBackNodeId: String
}

View File

@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
private var periodicTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
/// Checks if local network access is working (one-time check).
func check() {
performCheck()
}
/// Starts periodic checking of local network access.
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
func startPeriodicChecking(interval: TimeInterval = 10) {
stopPeriodicChecking()
// Do an immediate check first
performCheck()
// Then schedule periodic checks
periodicTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
guard !Task.isCancelled else { break }
self?.performCheck()
}
}
}
/// Stops periodic checking.
func stopPeriodicChecking() {
periodicTask?.cancel()
periodicTask = nil
}
private func performCheck() {
checkTask?.cancel()
status = .checking
// Only show "checking" status on first check to avoid UI flicker
if status == .unknown {
status = .checking
}
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
// Only log on state changes or first check to reduce noise
if isFirstCheck || result != self.status {
Self.logger.info("Local network check: \(result.displayText)")
}
}
}
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
}
func stop() {
stopPeriodicChecking()
checkTask?.cancel()
checkTask = nil
connection?.cancel()

View File

@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
// Legacy script path from older versions
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
}
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
return scriptExists || legacyScriptExists || plistExists
}
private static func makeUninstallScript() -> String {
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
# Remove the script (current and legacy paths) and parent directory if empty
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
@@ -107,90 +72,6 @@ enum NetworkSetupHelper {
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() async throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

View File

@@ -0,0 +1,252 @@
import AppKit
import Combine
import Foundation
import Security
import SystemConfiguration
import os.log
@MainActor
final class ThunderboltBridgeService: ObservableObject {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
@Published private(set) var detectedCycle: [String]?
@Published private(set) var hasPromptedForCurrentCycle = false
@Published private(set) var lastError: String?
private weak var clusterStateService: ClusterStateService?
private var cancellables = Set<AnyCancellable>()
private var previousCycleSignature: String?
init(clusterStateService: ClusterStateService) {
self.clusterStateService = clusterStateService
setupObserver()
}
private func setupObserver() {
guard let service = clusterStateService else { return }
service.$latestSnapshot
.compactMap { $0 }
.sink { [weak self] snapshot in
self?.checkForCycles(snapshot: snapshot)
}
.store(in: &cancellables)
}
private func checkForCycles(snapshot: ClusterState) {
let cycles = snapshot.thunderboltBridgeCycles
// Only consider cycles with more than 2 nodes
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
// No problematic cycles detected, reset state
if detectedCycle != nil {
detectedCycle = nil
previousCycleSignature = nil
hasPromptedForCurrentCycle = false
}
return
}
// Create a signature for this cycle to detect if it changed
let cycleSignature = firstCycle.sorted().joined(separator: ",")
// If this is a new/different cycle, reset the prompt state
if cycleSignature != previousCycleSignature {
previousCycleSignature = cycleSignature
hasPromptedForCurrentCycle = false
}
detectedCycle = firstCycle
// Only prompt once per cycle
if !hasPromptedForCurrentCycle {
showDisableBridgePrompt(nodeIds: firstCycle)
}
}
private func showDisableBridgePrompt(nodeIds: [String]) {
hasPromptedForCurrentCycle = true
// Get friendly names for the nodes if available
let nodeNames = nodeIds.map { nodeId -> String in
if let snapshot = clusterStateService?.latestSnapshot,
let profile = snapshot.nodeProfiles[nodeId],
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
{
return friendlyName
}
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
}
let machineNames = nodeNames.joined(separator: ", ")
let alert = NSAlert()
alert.messageText = "Thunderbolt Bridge Loop Detected"
alert.informativeText = """
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Disable Bridge")
alert.addButton(withTitle: "Not Now")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
Task {
await disableThunderboltBridge()
}
}
}
func disableThunderboltBridge() async {
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
lastError = nil
do {
try await disableThunderboltBridgeWithSCPreferences()
Self.logger.info("Successfully disabled Thunderbolt Bridge")
} catch {
Self.logger.error(
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
)
lastError = error.localizedDescription
showErrorAlert(message: error.localizedDescription)
}
}
private func disableThunderboltBridgeWithSCPreferences() async throws {
// 1. Create authorization reference
var authRef: AuthorizationRef?
var status = AuthorizationCreate(nil, nil, [], &authRef)
guard status == errAuthorizationSuccess, let authRef = authRef else {
throw ThunderboltBridgeError.authorizationFailed
}
defer { AuthorizationFree(authRef, [.destroyRights]) }
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled
}
throw ThunderboltBridgeError.authorizationDenied
}
// 3. Create SCPreferences with authorization
guard
let prefs = SCPreferencesCreateWithAuthorization(
kCFAllocatorDefault,
"EXO" as CFString,
nil,
authRef
)
else {
throw ThunderboltBridgeError.preferencesCreationFailed
}
// 4. Lock, modify, commit
guard SCPreferencesLock(prefs, true) else {
throw ThunderboltBridgeError.lockFailed
}
defer {
SCPreferencesUnlock(prefs)
}
// 5. Find and disable Thunderbolt Bridge service
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
throw ThunderboltBridgeError.servicesNotFound
}
var found = false
for service in allServices {
if let name = SCNetworkServiceGetName(service) as String?,
name == "Thunderbolt Bridge"
{
guard SCNetworkServiceSetEnabled(service, false) else {
throw ThunderboltBridgeError.disableFailed
}
found = true
Self.logger.info("Found and disabled Thunderbolt Bridge service")
break
}
}
if !found {
throw ThunderboltBridgeError.serviceNotFound
}
// 6. Commit and apply
guard SCPreferencesCommitChanges(prefs) else {
throw ThunderboltBridgeError.commitFailed
}
guard SCPreferencesApplyChanges(prefs) else {
throw ThunderboltBridgeError.applyFailed
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Failed to Disable Thunderbolt Bridge"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
}
enum ThunderboltBridgeError: LocalizedError {
case authorizationFailed
case authorizationCanceled
case authorizationDenied
case preferencesCreationFailed
case lockFailed
case servicesNotFound
case serviceNotFound
case disableFailed
case commitFailed
case applyFailed
var errorDescription: String? {
switch self {
case .authorizationFailed:
return "Failed to create authorization"
case .authorizationCanceled:
return "Authorization was canceled by user"
case .authorizationDenied:
return "Authorization was denied"
case .preferencesCreationFailed:
return "Failed to access network preferences"
case .lockFailed:
return "Failed to lock network preferences for modification"
case .servicesNotFound:
return "Could not retrieve network services"
case .serviceNotFound:
return "Thunderbolt Bridge service not found"
case .disableFailed:
return "Failed to disable Thunderbolt Bridge service"
case .commitFailed:
return "Failed to save network configuration changes"
case .applyFailed:
return "Failed to apply network configuration changes"
}
}
}

View File

@@ -86,7 +86,7 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let topologyNodeIds = Set(topology?.nodes ?? [])
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
@@ -95,8 +95,8 @@ extension ClusterState {
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
var orderedNodes: [NodeViewModel] = []
if let topologyNodes = topology?.nodes {
for topoNode in topologyNodes {
if let viewModel = nodesById[topoNode.nodeId] {
for nodeId in topologyNodes {
if let viewModel = nodesById[nodeId] {
orderedNodes.append(viewModel)
}
}
@@ -116,7 +116,7 @@ extension ClusterState {
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
topology?.connections.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode, nodeThunderboltBridge, type NodeInfo } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -16,6 +16,7 @@ import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$li
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -932,6 +933,25 @@ function wrapLine(text: string, maxLen: number): string[] {
.attr('fill', 'rgba(179,179,179,0.7)')
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}
// Debug mode: Show TB bridge status
if (debugEnabled) {
const tbStatus = tbBridgeData[nodeInfo.id];
if (tbStatus) {
const tbY = nodeInfo.y + iconBaseHeight / 2 + (showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const tbFontSize = showFullLabels ? 9 : 7;
const tbColor = tbStatus.enabled ? 'rgba(234,179,8,0.9)' : 'rgba(100,100,100,0.7)';
const tbText = tbStatus.enabled ? 'TB:ON' : 'TB:OFF';
nodeG.append('text')
.attr('x', nodeInfo.x)
.attr('y', tbY)
.attr('text-anchor', 'middle')
.attr('fill', tbColor)
.attr('font-size', tbFontSize)
.attr('font-family', 'SF Mono, Monaco, monospace')
.text(tbText);
}
}
});
}

View File

@@ -173,6 +173,12 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface RawThunderboltBridgeStatus {
enabled: boolean;
exists: boolean;
serviceName?: string | null;
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -190,6 +196,10 @@ interface RawStateResponse {
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
// Thunderbolt bridge status per node
nodeThunderboltBridge?: Record<string, RawThunderboltBridgeStatus>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
}
export interface MessageAttachment {
@@ -387,6 +397,13 @@ class AppStore {
selectedPreviewModelId = $state<string | null>(null);
isLoadingPreviews = $state(false);
lastUpdate = $state<number | null>(null);
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderboltBridge = $state<
Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>
>({});
// UI state
isTopologyMinimized = $state(false);
@@ -915,6 +932,16 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
if (data.thunderboltBridgeCycles) {
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles;
} else {
this.thunderboltBridgeCycles = [];
}
if (data.nodeThunderboltBridge) {
this.nodeThunderboltBridge = data.nodeThunderboltBridge;
} else {
this.nodeThunderboltBridge = {};
}
this.lastUpdate = Date.now();
} catch (error) {
console.error("Error fetching state:", error);
@@ -1620,6 +1647,8 @@ export const placementPreviews = () => appStore.placementPreviews;
export const selectedPreviewModelId = () => appStore.selectedPreviewModelId;
export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const debugMode = () => appStore.getDebugMode();

View File

@@ -1,9 +1,9 @@
<script lang="ts">
import { TopologyGraph, ChatForm, ChatMessages, ChatSidebar, ModelCard } from '$lib/components';
import {
hasStartedChat,
isTopologyMinimized,
topologyData,
import {
hasStartedChat,
isTopologyMinimized,
topologyData,
lastUpdate,
clearChat,
instances,
@@ -22,6 +22,8 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
thunderboltBridgeCycles,
nodeThunderboltBridge,
type DownloadProgress,
type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -43,6 +45,25 @@
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
}
// Helper to get TB bridge service name from any node in the cycle
function getTbBridgeServiceName(cycleNodes: string[]): string {
for (const nodeId of cycleNodes) {
const status = tbBridgeData[nodeId];
if (status?.serviceName) {
return status.serviceName;
}
}
return 'Thunderbolt Bridge'; // fallback
}
let mounted = $state(false);
@@ -128,7 +149,10 @@ let chatScrollRef: HTMLDivElement | null = $state(null);
// Preview card hover state for highlighting nodes in topology
let hoveredPreviewNodes = $state<Set<string>>(new Set());
// Thunderbolt bridge cycle hover state for highlighting nodes in topology
let hoveredBridgeCycleNodes = $state<Set<string>>(new Set());
// Helper to unwrap tagged instance for hover highlighting
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
if (!instanceWrapped || typeof instanceWrapped !== 'object') return new Set();
@@ -151,7 +175,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
instanceDownloadExpandedNodes = next;
}
// Compute highlighted nodes from hovered instance or hovered preview
// Compute highlighted nodes from hovered instance, preview, or bridge cycle
const highlightedNodes = $derived(() => {
// First check instance hover
if (hoveredInstanceId) {
@@ -162,6 +186,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (hoveredPreviewNodes.size > 0) {
return hoveredPreviewNodes;
}
// Then check bridge cycle hover
if (hoveredBridgeCycleNodes.size > 0) {
return hoveredBridgeCycleNodes;
}
return new Set<string>();
});
@@ -1233,6 +1261,58 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div
class="absolute top-4 left-4 group"
role="alert"
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
>
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
</div>
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
<div class="text-xs text-yellow-400 font-mono mb-3">
{cycle.length} machines detected in a Thunderbolt Bridge loop:
</div>
<div class="flex flex-wrap gap-1.5 mb-4">
{#each cycle as nodeId}
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
{getNodeName(nodeId)}
</span>
{/each}
</div>
<div class="text-xs text-white/60 mb-2">
Run on each machine to disable bridge:
</div>
<div class="relative group/cmd">
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
<button
type="button"
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
onclick={() => navigator.clipboard.writeText(disableCmd)}
title="Copy command"
>
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
</button>
</div>
</div>
</div>
</div>
{/if}
<!-- Exit topology-only mode button -->
<button
type="button"
@@ -1258,11 +1338,65 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Topology Container - Takes most of the space -->
<div class="flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden">
<!-- The main topology graph - full container -->
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div
class="absolute top-4 left-4 group"
role="alert"
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
>
<!-- Warning Badge -->
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
</div>
<!-- Hover Tooltip -->
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
<div class="text-xs text-yellow-400 font-mono mb-3">
{cycle.length} machines detected in a Thunderbolt Bridge loop:
</div>
<div class="flex flex-wrap gap-1.5 mb-4">
{#each cycle as nodeId}
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
{getNodeName(nodeId)}
</span>
{/each}
</div>
<div class="text-xs text-white/60 mb-2">
Run on each machine to disable bridge:
</div>
<div class="relative group/cmd">
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
<button
type="button"
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
onclick={() => navigator.clipboard.writeText(disableCmd)}
title="Copy command"
>
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
</button>
</div>
</div>
</div>
</div>
{/if}
</div>
<!-- Chat Input - Below topology -->
<div class="px-4 pt-6 pb-8">
<div class="max-w-3xl mx-auto">
@@ -1779,8 +1913,21 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
<div class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden">
<TopologyGraph highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
{#if tbBridgeCycles.length > 0}
<div
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
title="Thunderbolt Bridge cycle detected - click to view details"
>
<svg class="w-3 h-3 text-yellow-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-[10px] font-mono text-yellow-400">TB CYCLE</span>
</div>
{/if}
</div>
</button>

View File

@@ -24,6 +24,7 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,8 +19,11 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -86,12 +89,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await get_model_card(model_id)
return await ModelCard.from_hf(model_id)
class API:
@@ -236,7 +239,7 @@ class API:
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
@@ -551,7 +554,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -578,7 +581,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(

View File

@@ -29,6 +29,7 @@ from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -45,6 +46,7 @@ from exo.utils.info_gatherer.info_gatherer import (
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
ThunderboltBridgeInfo,
)
@@ -224,6 +226,15 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
node_thunderbolt_bridge = {
key: value
for key, value in state.node_thunderbolt_bridge.items()
if key != event.node_id
}
# Recompute cycles after removing the node
thunderbolt_bridge_cycles = topology.get_thunderbolt_bridge_cycles(
node_thunderbolt_bridge
)
return state.model_copy(
update={
"downloads": downloads,
@@ -234,6 +245,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
"node_thunderbolt_bridge": node_thunderbolt_bridge,
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
}
)
@@ -311,6 +324,16 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
case ThunderboltBridgeInfo():
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
**state.node_thunderbolt_bridge,
event.node_id: info.status,
}
update["node_thunderbolt_bridge"] = new_tb_bridge
# Recompute cycles with updated bridge status
update["thunderbolt_bridge_cycles"] = (
topology.get_thunderbolt_bridge_cycles(new_tb_bridge)
)
return state.model_copy(update=update)

View File

@@ -1,16 +1,18 @@
from pydantic import PositiveInt
from typing import Annotated
from exo.shared.types.common import Id
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
@@ -20,6 +22,43 @@ class ModelCard(CamelCaseModel):
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
@@ -308,3 +347,99 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)

View File

@@ -1,122 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -7,6 +7,7 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ThunderboltBridgeStatus
from exo.shared.types.topology import (
Connection,
Cycle,
@@ -234,3 +235,31 @@ class Topology:
if not has_tb:
return False
return True
    def get_thunderbolt_bridge_cycles(
        self,
        node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
    ) -> list[list[NodeId]]:
        """
        Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.

        Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
        2 or fewer nodes don't cause the broadcast storm problem.
        """
        tb_cycles = self.get_cycles_tb()
        result: list[list[NodeId]] = []
        for cycle in tb_cycles:
            node_ids = list(cycle.node_ids)
            # Only consider cycles with more than 2 nodes
            if len(node_ids) <= 2:
                continue
            # Check if all nodes in the cycle have TB bridge enabled
            all_enabled = all(
                node_id in node_tb_bridge_status
                and node_tb_bridge_status[node_id].enabled
                for node_id in node_ids
            )
            if all_enabled:
                result.append(node_ids)
        return result
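A rough usage sketch follows. The topology object, the tb() helper and the node IDs are hypothetical; in practice the statuses come from the node_thunderbolt_bridge mapping in the shared State, and NodeId is assumed to be constructible from a string:

def tb(enabled: bool) -> ThunderboltBridgeStatus:
    return ThunderboltBridgeStatus(
        enabled=enabled, exists=True, service_name="Thunderbolt Bridge"
    )

status = {NodeId("mac-a"): tb(True), NodeId("mac-b"): tb(True), NodeId("mac-c"): tb(False)}
# With mac-c's bridge disabled no cycle is reported even if the three Macs form
# a Thunderbolt loop; enabling it on all three would yield one cycle, e.g.
# [[NodeId("mac-a"), NodeId("mac-b"), NodeId("mac-c")]] (ordering depends on detection).
cycles = topology.get_thunderbolt_bridge_cycles(status)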

View File

@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

View File

@@ -25,6 +25,14 @@ class NodeId(Id):
    pass


class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class SessionId(CamelCaseModel):
    master_node_id: NodeId
    election_clock: int
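A small sketch of the two helpers (assuming Id is a str subclass, which the use of self.replace and self.split above implies; the model ID is arbitrary):

mid = ModelId("mlx-community/Llama-3.1-8B-Instruct-4bit")
mid.normalize()  # "mlx-community--Llama-3.1-8B-Instruct-4bit", safe as a directory name
mid.short()      # "Llama-3.1-8B-Instruct-4bit"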

View File

@@ -71,3 +71,11 @@ class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
class ThunderboltBridgeStatus(CamelCaseModel):
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
enabled: bool
exists: bool
service_name: str | None = None

View File

@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
ThunderboltBridgeStatus,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
    node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
    node_network: Mapping[NodeId, NodeNetworkInfo] = {}
    node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
    node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
    # Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
    thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []

    @field_serializer("topology", mode="plain")
    def _encode_topology(self, value: Topology) -> TopologySnapshot:

View File

@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
    DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)


class ModelSafetensorsIndexMetadata(BaseModel):
    total_size: PositiveInt


class ModelSafetensorsIndex(BaseModel):
    metadata: ModelSafetensorsIndexMetadata | None
    weight_map: dict[str, str]


class FileListEntry(BaseModel):
    type: Literal["file", "directory"]
    path: str
    size: int | None = None


class RepoFileDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    file_path: str
    downloaded: Memory
    downloaded_this_session: Memory
    total: Memory
    speed: float
    eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    start_time: float

    model_config = ConfigDict(frozen=True)


class RepoDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    shard: ShardMetadata
    completed_files: int
    total_files: int
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)

    model_config = ConfigDict(frozen=True)

View File

@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
IS_DARWIN = sys.platform == "darwin"
async def _get_thunderbolt_devices() -> set[str] | None:
    """Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.

    Returns None if the networksetup command fails.
    """
    result = await anyio.run_process(
        ["networksetup", "-listallhardwareports"],
        check=False,
    )
    if result.returncode != 0:
        logger.warning(
            f"networksetup -listallhardwareports failed with code "
            f"{result.returncode}: {result.stderr.decode()}"
        )
        return None
    output = result.stdout.decode()
    thunderbolt_devices: set[str] = set()
    current_port: str | None = None
    for line in output.splitlines():
        line = line.strip()
        if line.startswith("Hardware Port:"):
            current_port = line.split(":", 1)[1].strip()
        elif line.startswith("Device:") and current_port:
            device = line.split(":", 1)[1].strip()
            if "thunderbolt" in current_port.lower():
                thunderbolt_devices.add(device)
            current_port = None
    return thunderbolt_devices


async def _get_bridge_services() -> dict[str, str] | None:
    """Get mapping of bridge device -> service name from network service order.

    Returns None if the networksetup command fails.
    """
    result = await anyio.run_process(
        ["networksetup", "-listnetworkserviceorder"],
        check=False,
    )
    if result.returncode != 0:
        logger.warning(
            f"networksetup -listnetworkserviceorder failed with code "
            f"{result.returncode}: {result.stderr.decode()}"
        )
        return None
    # Parse service order to find bridge devices and their service names
    # Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
    service_order_output = result.stdout.decode()
    bridge_services: dict[str, str] = {}  # device -> service name
    current_service: str | None = None
    for line in service_order_output.splitlines():
        line = line.strip()
        # Match "(N) Service Name" or "(*) Service Name" (disabled)
        # but NOT "(Hardware Port: ...)" lines
        if (
            line
            and line.startswith("(")
            and ")" in line
            and not line.startswith("(Hardware Port:")
        ):
            paren_end = line.index(")")
            if paren_end + 2 <= len(line):
                current_service = line[paren_end + 2 :]
        # Match "(Hardware Port: ..., Device: bridgeX)"
        elif current_service and "Device: bridge" in line:
            # Extract device name from "..., Device: bridge0)"
            device_start = line.find("Device: ") + len("Device: ")
            device_end = line.find(")", device_start)
            if device_end > device_start:
                device = line[device_start:device_end]
                bridge_services[device] = current_service
    return bridge_services


async def _get_bridge_members(bridge_device: str) -> set[str]:
    """Get member interfaces of a bridge device via ifconfig."""
    result = await anyio.run_process(
        ["ifconfig", bridge_device],
        check=False,
    )
    if result.returncode != 0:
        logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
        return set()
    members: set[str] = set()
    ifconfig_output = result.stdout.decode()
    for line in ifconfig_output.splitlines():
        line = line.strip()
        if line.startswith("member:"):
            parts = line.split()
            if len(parts) > 1:
                members.add(parts[1])
    return members


async def _find_thunderbolt_bridge(
    bridge_services: dict[str, str], thunderbolt_devices: set[str]
) -> str | None:
    """Find the service name of a bridge containing Thunderbolt interfaces.

    Returns the service name if found, None otherwise.
    """
    for bridge_device, service_name in bridge_services.items():
        members = await _get_bridge_members(bridge_device)
        if members & thunderbolt_devices:  # intersection is non-empty
            return service_name
    return None


async def _is_service_enabled(service_name: str) -> bool | None:
    """Check if a network service is enabled.

    Returns True if enabled, False if disabled, None on error.
    """
    result = await anyio.run_process(
        ["networksetup", "-getnetworkserviceenabled", service_name],
        check=False,
    )
    if result.returncode != 0:
        logger.warning(
            f"networksetup -getnetworkserviceenabled '{service_name}' "
            f"failed with code {result.returncode}: {result.stderr.decode()}"
        )
        return None
    stdout = result.stdout.decode().strip().lower()
    return stdout == "enabled"
class StaticNodeInformation(TaggedModel):
    """Node information that should NEVER change, to be gathered once at startup"""
@@ -58,6 +195,64 @@ class MacThunderboltConnections(TaggedModel):
    conns: Sequence[ThunderboltConnection]


class ThunderboltBridgeInfo(TaggedModel):
    status: ThunderboltBridgeStatus

    @classmethod
    async def gather(cls) -> Self | None:
        """Check if a Thunderbolt Bridge network service is enabled on this node.

        Detection approach:
        1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
        2. Find bridge devices from network service order (not hardware ports, as
           bridges may not appear there)
        3. Check each bridge's members via ifconfig
        4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
        5. Check if that network service is enabled
        """
        if not IS_DARWIN:
            return None

        def _no_bridge_status() -> Self:
            return cls(
                status=ThunderboltBridgeStatus(
                    enabled=False, exists=False, service_name=None
                )
            )

        try:
            tb_devices = await _get_thunderbolt_devices()
            if tb_devices is None:
                return _no_bridge_status()
            bridge_services = await _get_bridge_services()
            if not bridge_services:
                return _no_bridge_status()
            tb_service_name = await _find_thunderbolt_bridge(
                bridge_services, tb_devices
            )
            if not tb_service_name:
                return _no_bridge_status()
            enabled = await _is_service_enabled(tb_service_name)
            if enabled is None:
                return cls(
                    status=ThunderboltBridgeStatus(
                        enabled=False, exists=True, service_name=tb_service_name
                    )
                )
            return cls(
                status=ThunderboltBridgeStatus(
                    enabled=enabled,
                    exists=True,
                    service_name=tb_service_name,
                )
            )
        except Exception as e:
            logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
            return None


class NodeConfig(TaggedModel):
    """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@@ -111,6 +306,7 @@ GatheredInfo = (
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| ThunderboltBridgeInfo
| NodeConfig
| MiscData
| StaticNodeInformation
@@ -125,6 +321,7 @@ class InfoGatherer:
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
@@ -133,6 +330,7 @@ class InfoGatherer:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
@@ -206,6 +404,17 @@ class InfoGatherer:
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
    async def _monitor_thunderbolt_bridge_status(self):
        if self.thunderbolt_bridge_poll_interval is None:
            return
        prev: ThunderboltBridgeInfo | None = None
        while True:
            curr = await ThunderboltBridgeInfo.gather()
            if curr is not None and prev != curr:
                prev = curr
                await self.info_sender.send(curr)
            await anyio.sleep(self.thunderbolt_bridge_poll_interval)
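One detail worth noting: the prev != curr check relies on pydantic value equality, so an unchanged status is not re-sent every poll. A tiny sketch, assuming TaggedModel can be constructed with just its declared fields and keeps pydantic's default equality (field values are arbitrary):

tb = ThunderboltBridgeStatus(enabled=True, exists=True, service_name="Thunderbolt Bridge")
a, b = ThunderboltBridgeInfo(status=tb), ThunderboltBridgeInfo(status=tb)
assert a == b  # equal by value, so the loop above skips the send when nothing changed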
    async def _monitor_macmon(self, macmon_path: str):
        if self.macmon_interval is None:
            return

View File

@@ -17,17 +17,20 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def ensure_models_dir() -> Path:
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
repo_id: str, revision: str = "main", recursive: bool = False
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(repo_id, revision, path, recursive)
return await _fetch_file_list(model_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -219,7 +173,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
repo_id: str, revision: str, path: str, redirected_location: str | None = None
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -298,7 +252,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
return await file_meta(model_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -310,7 +264,7 @@ async def file_meta(
async def download_file_with_retry(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -320,23 +274,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
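For a sense of the retry pacing used above, the capped exponential backoff evaluates to roughly these per-attempt delays in seconds:

delays = [min(8, 0.1 * (2.0**attempt)) for attempt in range(10)]
# approx. [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8, 8, 8] - capped at 8s once
# 0.1 * 2**attempt exceeds 8
print(delays)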
async def _download_file(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -345,7 +299,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -354,7 +308,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -394,7 +348,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
repo_id: str,
model_id: ModelId,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -423,7 +377,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -442,11 +396,11 @@ def calculate_repo_progress(
)
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
model_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -477,53 +431,6 @@ async def get_downloaded_size(path: Path) -> int:
return 0
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
start_time=time.time(),
)
total_files += 1
total_bytes += size
else:
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
file_progress=file_progress,
)
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -534,14 +441,6 @@ async def download_shard(
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
@@ -552,13 +451,14 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +492,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +509,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file_progress,
all_start_time,
@@ -619,7 +519,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +543,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file.path,
target_dir,
@@ -657,7 +557,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, shard.model_card.model_id, revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,8 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,8 +18,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -31,7 +30,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: str,
model_id: ModelId,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -83,11 +83,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
return cast(_LayerCallable, self["_original_layer"])
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -96,7 +96,7 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
original_layer = cast(_LayerCallable, self["_original_layer"])
return getattr(original_layer, name)
@@ -334,7 +334,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
@@ -383,7 +383,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")

View File

@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -296,7 +297,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -449,7 +449,7 @@ class Worker:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
@@ -480,7 +480,7 @@ class Worker:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -11,8 +11,9 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: str) -> Path:
async def download_tokenizer_files(model_id: ModelId) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: str) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = card
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
TEST_MODELS: list[ModelCard] = get_test_models()
pytestmark = pytest.mark.slow
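A rough illustration of the family bucketing performed by get_test_models above (the short IDs are made up):

for short in ["llama-3.1-8b", "llama-3.1-70b", "qwen3-30b-a3b"]:
    parts = short.split("-")
    family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
    print(short, "->", family)
# llama-3.1-8b  -> llama-3.1   (kept: first card seen for this family)
# llama-3.1-70b -> llama-3.1   (skipped: family already present)
# qwen3-30b-a3b -> qwen3-30b   (kept)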
@@ -101,14 +101,13 @@ def event_loop():
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -167,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -209,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -301,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_card = kimi_models[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -349,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_models:
if not glm_model_cards:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_card = glm_model_cards[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,6 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

uv.lock generated
View File

@@ -248,6 +248,7 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -281,6 +282,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -315,6 +317,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"