Compare commits

..

2 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
8f6f2f3065 Add fixes 2026-01-20 17:13:02 +00:00
Evan
e6af53c2ae foo 2026-01-20 17:12:31 +00:00
32 changed files with 512 additions and 1312 deletions

View File

@@ -315,7 +315,7 @@ jobs:
--apple-id "$APPLE_NOTARIZATION_USERNAME" \
--password "$APPLE_NOTARIZATION_PASSWORD" \
--team-id "$APPLE_NOTARIZATION_TEAM" \
--wait --timeout 20m 2>&1)
--wait --timeout 15m 2>&1)
echo "$SUBMISSION_OUTPUT"
SUBMISSION_ID=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="id:" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')

View File

@@ -14,7 +14,6 @@ struct ContentView: View {
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false

View File

@@ -21,7 +21,6 @@ struct EXOApp: App {
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -42,13 +41,10 @@ struct EXOApp: App {
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -62,7 +58,6 @@ struct EXOApp: App {
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
.environmentObject(thunderboltBridgeService)
} label: {
menuBarIcon
}
@@ -135,20 +130,6 @@ struct EXOApp: App {
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
/// Helper for managing EXO's launch-at-login registration

View File

@@ -5,43 +5,17 @@ import Foundation
struct ClusterState: Decodable {
let instances: [String: ClusterInstance]
let runners: [String: RunnerStatusSummary]
let nodeProfiles: [String: NodeProfile]
let tasks: [String: ClusterTask]
let topology: Topology?
let downloads: [String: [NodeDownloadStatus]]
let thunderboltBridgeCycles: [[String]]
// Granular node state (split from the old nodeProfiles)
let nodeIdentities: [String: NodeIdentity]
let nodeMemory: [String: MemoryInfo]
let nodeSystem: [String: SystemInfo]
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
/// Computed property for backwards compatibility - merges granular state into NodeProfile
var nodeProfiles: [String: NodeProfile] {
var profiles: [String: NodeProfile] = [:]
let allNodeIds = Set(nodeIdentities.keys)
.union(nodeMemory.keys)
.union(nodeSystem.keys)
for nodeId in allNodeIds {
let identity = nodeIdentities[nodeId]
let memory = nodeMemory[nodeId]
let system = nodeSystem[nodeId]
profiles[nodeId] = NodeProfile(
modelId: identity?.modelId,
chipId: identity?.chipId,
friendlyName: identity?.friendlyName,
memory: memory,
system: system
)
}
return profiles
}
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
@@ -50,34 +24,15 @@ struct ClusterState: Decodable {
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
self.thunderboltBridgeCycles =
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
// Granular node state
self.nodeIdentities =
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
?? [:]
self.nodeMemory =
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
self.nodeSystem =
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
self.nodeThunderboltBridge =
try container.decodeIfPresent(
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
) ?? [:]
}
private enum CodingKeys: String, CodingKey {
case instances
case runners
case nodeProfiles
case topology
case tasks
case downloads
case thunderboltBridgeCycles
case nodeIdentities
case nodeMemory
case nodeSystem
case nodeThunderboltBridge
}
}
@@ -147,18 +102,6 @@ struct NodeProfile: Decodable {
let system: SystemInfo?
}
struct NodeIdentity: Decodable {
let modelId: String?
let chipId: String?
let friendlyName: String?
}
struct ThunderboltBridgeStatus: Decodable {
let enabled: Bool
let exists: Bool
let serviceName: String?
}
struct MemoryInfo: Decodable {
let ramTotal: MemoryValue?
let ramAvailable: MemoryValue?
@@ -177,51 +120,16 @@ struct SystemInfo: Decodable {
}
struct Topology: Decodable {
/// Node IDs in the topology
let nodes: [String]
/// Flattened list of connections (source -> sink pairs)
let connections: [TopologyConnection]
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
// Connections come as nested map: { source: { sink: [edges] } }
// We flatten to array of (source, sink) pairs
var flatConnections: [TopologyConnection] = []
if let nested = try container.decodeIfPresent(
[String: [String: [AnyCodable]]].self, forKey: .connections
) {
for (source, sinks) in nested {
for sink in sinks.keys {
flatConnections.append(
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
}
}
}
self.connections = flatConnections
}
private enum CodingKeys: String, CodingKey {
case nodes
case connections
}
let nodes: [TopologyNode]
let connections: [TopologyConnection]?
}
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
private struct AnyCodable: Decodable {
init(from decoder: Decoder) throws {
// Just consume the value without storing it
_ = try? decoder.singleValueContainer().decode(Bool.self)
_ = try? decoder.singleValueContainer().decode(Int.self)
_ = try? decoder.singleValueContainer().decode(Double.self)
_ = try? decoder.singleValueContainer().decode(String.self)
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
}
struct TopologyNode: Decodable {
let nodeId: String
let nodeProfile: NodeProfile
}
struct TopologyConnection {
struct TopologyConnection: Decodable {
let localNodeId: String
let sendBackNodeId: String
}

View File

@@ -41,7 +41,6 @@ final class LocalNetworkChecker: ObservableObject {
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
private var periodicTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
@@ -49,39 +48,10 @@ final class LocalNetworkChecker: ObservableObject {
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working (one-time check).
/// Checks if local network access is working.
func check() {
performCheck()
}
/// Starts periodic checking of local network access.
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
func startPeriodicChecking(interval: TimeInterval = 10) {
stopPeriodicChecking()
// Do an immediate check first
performCheck()
// Then schedule periodic checks
periodicTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
guard !Task.isCancelled else { break }
self?.performCheck()
}
}
}
/// Stops periodic checking.
func stopPeriodicChecking() {
periodicTask?.cancel()
periodicTask = nil
}
private func performCheck() {
checkTask?.cancel()
// Only show "checking" status on first check to avoid UI flicker
if status == .unknown {
status = .checking
}
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
@@ -90,15 +60,12 @@ final class LocalNetworkChecker: ObservableObject {
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
// Only log on state changes or first check to reduce noise
if isFirstCheck || result != self.status {
Self.logger.info("Local network check: \(result.displayText)")
}
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
@@ -174,7 +141,6 @@ final class LocalNetworkChecker: ObservableObject {
}
func stop() {
stopPeriodicChecking()
checkTask?.cancel()
checkTask = nil
connection?.cancel()

View File

@@ -7,10 +7,48 @@ enum NetworkSetupHelper {
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
// Legacy script path from older versions
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
}
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -25,9 +63,8 @@ enum NetworkSetupHelper {
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || legacyScriptExists || plistExists
return scriptExists || plistExists
}
private static func makeUninstallScript() -> String {
@@ -36,7 +73,6 @@ enum NetworkSetupHelper {
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
@@ -47,9 +83,8 @@ enum NetworkSetupHelper {
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script (current and legacy paths) and parent directory if empty
# Remove the script and parent directory if empty
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
@@ -72,6 +107,90 @@ enum NetworkSetupHelper {
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() async throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

View File

@@ -1,252 +0,0 @@
import AppKit
import Combine
import Foundation
import Security
import SystemConfiguration
import os.log
@MainActor
final class ThunderboltBridgeService: ObservableObject {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
@Published private(set) var detectedCycle: [String]?
@Published private(set) var hasPromptedForCurrentCycle = false
@Published private(set) var lastError: String?
private weak var clusterStateService: ClusterStateService?
private var cancellables = Set<AnyCancellable>()
private var previousCycleSignature: String?
init(clusterStateService: ClusterStateService) {
self.clusterStateService = clusterStateService
setupObserver()
}
private func setupObserver() {
guard let service = clusterStateService else { return }
service.$latestSnapshot
.compactMap { $0 }
.sink { [weak self] snapshot in
self?.checkForCycles(snapshot: snapshot)
}
.store(in: &cancellables)
}
private func checkForCycles(snapshot: ClusterState) {
let cycles = snapshot.thunderboltBridgeCycles
// Only consider cycles with more than 2 nodes
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
// No problematic cycles detected, reset state
if detectedCycle != nil {
detectedCycle = nil
previousCycleSignature = nil
hasPromptedForCurrentCycle = false
}
return
}
// Create a signature for this cycle to detect if it changed
let cycleSignature = firstCycle.sorted().joined(separator: ",")
// If this is a new/different cycle, reset the prompt state
if cycleSignature != previousCycleSignature {
previousCycleSignature = cycleSignature
hasPromptedForCurrentCycle = false
}
detectedCycle = firstCycle
// Only prompt once per cycle
if !hasPromptedForCurrentCycle {
showDisableBridgePrompt(nodeIds: firstCycle)
}
}
private func showDisableBridgePrompt(nodeIds: [String]) {
hasPromptedForCurrentCycle = true
// Get friendly names for the nodes if available
let nodeNames = nodeIds.map { nodeId -> String in
if let snapshot = clusterStateService?.latestSnapshot,
let profile = snapshot.nodeProfiles[nodeId],
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
{
return friendlyName
}
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
}
let machineNames = nodeNames.joined(separator: ", ")
let alert = NSAlert()
alert.messageText = "Thunderbolt Bridge Loop Detected"
alert.informativeText = """
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Disable Bridge")
alert.addButton(withTitle: "Not Now")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
Task {
await disableThunderboltBridge()
}
}
}
func disableThunderboltBridge() async {
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
lastError = nil
do {
try await disableThunderboltBridgeWithSCPreferences()
Self.logger.info("Successfully disabled Thunderbolt Bridge")
} catch {
Self.logger.error(
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
)
lastError = error.localizedDescription
showErrorAlert(message: error.localizedDescription)
}
}
private func disableThunderboltBridgeWithSCPreferences() async throws {
// 1. Create authorization reference
var authRef: AuthorizationRef?
var status = AuthorizationCreate(nil, nil, [], &authRef)
guard status == errAuthorizationSuccess, let authRef = authRef else {
throw ThunderboltBridgeError.authorizationFailed
}
defer { AuthorizationFree(authRef, [.destroyRights]) }
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled
}
throw ThunderboltBridgeError.authorizationDenied
}
// 3. Create SCPreferences with authorization
guard
let prefs = SCPreferencesCreateWithAuthorization(
kCFAllocatorDefault,
"EXO" as CFString,
nil,
authRef
)
else {
throw ThunderboltBridgeError.preferencesCreationFailed
}
// 4. Lock, modify, commit
guard SCPreferencesLock(prefs, true) else {
throw ThunderboltBridgeError.lockFailed
}
defer {
SCPreferencesUnlock(prefs)
}
// 5. Find and disable Thunderbolt Bridge service
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
throw ThunderboltBridgeError.servicesNotFound
}
var found = false
for service in allServices {
if let name = SCNetworkServiceGetName(service) as String?,
name == "Thunderbolt Bridge"
{
guard SCNetworkServiceSetEnabled(service, false) else {
throw ThunderboltBridgeError.disableFailed
}
found = true
Self.logger.info("Found and disabled Thunderbolt Bridge service")
break
}
}
if !found {
throw ThunderboltBridgeError.serviceNotFound
}
// 6. Commit and apply
guard SCPreferencesCommitChanges(prefs) else {
throw ThunderboltBridgeError.commitFailed
}
guard SCPreferencesApplyChanges(prefs) else {
throw ThunderboltBridgeError.applyFailed
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Failed to Disable Thunderbolt Bridge"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
}
enum ThunderboltBridgeError: LocalizedError {
case authorizationFailed
case authorizationCanceled
case authorizationDenied
case preferencesCreationFailed
case lockFailed
case servicesNotFound
case serviceNotFound
case disableFailed
case commitFailed
case applyFailed
var errorDescription: String? {
switch self {
case .authorizationFailed:
return "Failed to create authorization"
case .authorizationCanceled:
return "Authorization was canceled by user"
case .authorizationDenied:
return "Authorization was denied"
case .preferencesCreationFailed:
return "Failed to access network preferences"
case .lockFailed:
return "Failed to lock network preferences for modification"
case .servicesNotFound:
return "Could not retrieve network services"
case .serviceNotFound:
return "Thunderbolt Bridge service not found"
case .disableFailed:
return "Failed to disable Thunderbolt Bridge service"
case .commitFailed:
return "Failed to save network configuration changes"
case .applyFailed:
return "Failed to apply network configuration changes"
}
}
}

View File

@@ -86,7 +86,7 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes ?? [])
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
@@ -95,8 +95,8 @@ extension ClusterState {
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
var orderedNodes: [NodeViewModel] = []
if let topologyNodes = topology?.nodes {
for nodeId in topologyNodes {
if let viewModel = nodesById[nodeId] {
for topoNode in topologyNodes {
if let viewModel = nodesById[topoNode.nodeId] {
orderedNodes.append(viewModel)
}
}
@@ -116,7 +116,7 @@ extension ClusterState {
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections.compactMap { connection in
topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, nodeThunderboltBridge, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -16,7 +16,6 @@ import { topologyData, isTopologyMinimized, debugMode, nodeThunderboltBridge, ty
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -933,25 +932,6 @@ function wrapLine(text: string, maxLen: number): string[] {
.attr('fill', 'rgba(179,179,179,0.7)')
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}
// Debug mode: Show TB bridge status
if (debugEnabled) {
const tbStatus = tbBridgeData[nodeInfo.id];
if (tbStatus) {
const tbY = nodeInfo.y + iconBaseHeight / 2 + (showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const tbFontSize = showFullLabels ? 9 : 7;
const tbColor = tbStatus.enabled ? 'rgba(234,179,8,0.9)' : 'rgba(100,100,100,0.7)';
const tbText = tbStatus.enabled ? 'TB:ON' : 'TB:OFF';
nodeG.append('text')
.attr('x', nodeInfo.x)
.attr('y', tbY)
.attr('text-anchor', 'middle')
.attr('fill', tbColor)
.attr('font-size', tbFontSize)
.attr('font-family', 'SF Mono, Monaco, monospace')
.text(tbText);
}
}
});
}

View File

@@ -173,12 +173,6 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface RawThunderboltBridgeStatus {
enabled: boolean;
exists: boolean;
serviceName?: string | null;
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -196,10 +190,6 @@ interface RawStateResponse {
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
// Thunderbolt bridge status per node
nodeThunderboltBridge?: Record<string, RawThunderboltBridgeStatus>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
}
export interface MessageAttachment {
@@ -397,13 +387,6 @@ class AppStore {
selectedPreviewModelId = $state<string | null>(null);
isLoadingPreviews = $state(false);
lastUpdate = $state<number | null>(null);
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderboltBridge = $state<
Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>
>({});
// UI state
isTopologyMinimized = $state(false);
@@ -932,16 +915,6 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
if (data.thunderboltBridgeCycles) {
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles;
} else {
this.thunderboltBridgeCycles = [];
}
if (data.nodeThunderboltBridge) {
this.nodeThunderboltBridge = data.nodeThunderboltBridge;
} else {
this.nodeThunderboltBridge = {};
}
this.lastUpdate = Date.now();
} catch (error) {
console.error("Error fetching state:", error);
@@ -1647,8 +1620,6 @@ export const placementPreviews = () => appStore.placementPreviews;
export const selectedPreviewModelId = () => appStore.selectedPreviewModelId;
export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const debugMode = () => appStore.getDebugMode();

View File

@@ -1,9 +1,9 @@
<script lang="ts">
import { TopologyGraph, ChatForm, ChatMessages, ChatSidebar, ModelCard } from '$lib/components';
import {
hasStartedChat,
isTopologyMinimized,
topologyData,
import {
hasStartedChat,
isTopologyMinimized,
topologyData,
lastUpdate,
clearChat,
instances,
@@ -22,8 +22,6 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
thunderboltBridgeCycles,
nodeThunderboltBridge,
type DownloadProgress,
type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -45,25 +43,6 @@
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
}
// Helper to get TB bridge service name from any node in the cycle
function getTbBridgeServiceName(cycleNodes: string[]): string {
for (const nodeId of cycleNodes) {
const status = tbBridgeData[nodeId];
if (status?.serviceName) {
return status.serviceName;
}
}
return 'Thunderbolt Bridge'; // fallback
}
let mounted = $state(false);
@@ -149,10 +128,7 @@ let chatScrollRef: HTMLDivElement | null = $state(null);
// Preview card hover state for highlighting nodes in topology
let hoveredPreviewNodes = $state<Set<string>>(new Set());
// Thunderbolt bridge cycle hover state for highlighting nodes in topology
let hoveredBridgeCycleNodes = $state<Set<string>>(new Set());
// Helper to unwrap tagged instance for hover highlighting
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
if (!instanceWrapped || typeof instanceWrapped !== 'object') return new Set();
@@ -175,7 +151,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
instanceDownloadExpandedNodes = next;
}
// Compute highlighted nodes from hovered instance, preview, or bridge cycle
// Compute highlighted nodes from hovered instance or hovered preview
const highlightedNodes = $derived(() => {
// First check instance hover
if (hoveredInstanceId) {
@@ -186,10 +162,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (hoveredPreviewNodes.size > 0) {
return hoveredPreviewNodes;
}
// Then check bridge cycle hover
if (hoveredBridgeCycleNodes.size > 0) {
return hoveredBridgeCycleNodes;
}
return new Set<string>();
});
@@ -1261,58 +1233,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div
class="absolute top-4 left-4 group"
role="alert"
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
>
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
</div>
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
<div class="text-xs text-yellow-400 font-mono mb-3">
{cycle.length} machines detected in a Thunderbolt Bridge loop:
</div>
<div class="flex flex-wrap gap-1.5 mb-4">
{#each cycle as nodeId}
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
{getNodeName(nodeId)}
</span>
{/each}
</div>
<div class="text-xs text-white/60 mb-2">
Run on each machine to disable bridge:
</div>
<div class="relative group/cmd">
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
<button
type="button"
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
onclick={() => navigator.clipboard.writeText(disableCmd)}
title="Copy command"
>
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
</button>
</div>
</div>
</div>
</div>
{/if}
<!-- Exit topology-only mode button -->
<button
type="button"
@@ -1338,65 +1258,11 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Topology Container - Takes most of the space -->
<div class="flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden">
<!-- The main topology graph - full container -->
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div
class="absolute top-4 left-4 group"
role="alert"
onmouseenter={() => hoveredBridgeCycleNodes = new Set(cycle)}
onmouseleave={() => hoveredBridgeCycleNodes = new Set()}
>
<!-- Warning Badge -->
<div class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-default">
<svg class="w-4 h-4 text-yellow-400 flex-shrink-0" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-xs font-mono text-yellow-400 tracking-wider">THUNDERBOLT BRIDGE CYCLE</span>
</div>
<!-- Hover Tooltip -->
<div class="absolute top-full left-0 mt-2 w-96 opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50">
<div class="p-4 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm shadow-xl">
<div class="text-xs text-yellow-400 font-mono mb-3">
{cycle.length} machines detected in a Thunderbolt Bridge loop:
</div>
<div class="flex flex-wrap gap-1.5 mb-4">
{#each cycle as nodeId}
<span class="px-2 py-1 text-xs font-mono text-yellow-300 bg-yellow-500/20 rounded border border-yellow-500/30">
{getNodeName(nodeId)}
</span>
{/each}
</div>
<div class="text-xs text-white/60 mb-2">
Run on each machine to disable bridge:
</div>
<div class="relative group/cmd">
<pre class="text-xs font-mono text-white/90 bg-black/40 p-2.5 rounded border border-white/10 overflow-x-auto"><code>{disableCmd}</code></pre>
<button
type="button"
class="absolute top-1.5 right-1.5 p-1 rounded bg-white/10 hover:bg-white/20 transition-colors opacity-0 group-hover/cmd:opacity-100"
onclick={() => navigator.clipboard.writeText(disableCmd)}
title="Copy command"
>
<svg class="w-3.5 h-3.5 text-white/70" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
</button>
</div>
</div>
</div>
</div>
{/if}
</div>
<!-- Chat Input - Below topology -->
<div class="px-4 pt-6 pb-8">
<div class="max-w-3xl mx-auto">
@@ -1913,21 +1779,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
<div class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden">
<TopologyGraph highlightedNodes={highlightedNodes()} />
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
{#if tbBridgeCycles.length > 0}
<div
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
title="Thunderbolt Bridge cycle detected - click to view details"
>
<svg class="w-3 h-3 text-yellow-400" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span class="text-[10px] font-mono text-yellow-400">TB CYCLE</span>
</div>
{/if}
</div>
</button>

View File

@@ -24,7 +24,6 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,11 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -89,12 +86,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await ModelCard.from_hf(model_id)
return await get_model_card(model_id)
class API:
@@ -239,7 +236,7 @@ class API:
async def get_placement(
self,
model_id: ModelId,
model_id: str,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
@@ -554,7 +551,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
@@ -581,7 +578,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(

View File

@@ -29,7 +29,6 @@ from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -46,7 +45,6 @@ from exo.utils.info_gatherer.info_gatherer import (
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
ThunderboltBridgeInfo,
)
@@ -226,15 +224,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
node_thunderbolt_bridge = {
key: value
for key, value in state.node_thunderbolt_bridge.items()
if key != event.node_id
}
# Recompute cycles after removing the node
thunderbolt_bridge_cycles = topology.get_thunderbolt_bridge_cycles(
node_thunderbolt_bridge
)
return state.model_copy(
update={
"downloads": downloads,
@@ -245,8 +234,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
"node_thunderbolt_bridge": node_thunderbolt_bridge,
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
}
)
@@ -324,16 +311,6 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
case ThunderboltBridgeInfo():
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
**state.node_thunderbolt_bridge,
event.node_id: info.status,
}
update["node_thunderbolt_bridge"] = new_tb_bridge
# Recompute cycles with updated bridge status
update["thunderbolt_bridge_cycles"] = (
topology.get_thunderbolt_bridge_cycles(new_tb_bridge)
)
return state.model_copy(update=update)

View File

@@ -1,18 +1,16 @@
from typing import Annotated
from pydantic import PositiveInt
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
@@ -22,43 +20,6 @@ class ModelCard(CamelCaseModel):
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
@@ -347,99 +308,3 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)

View File

@@ -0,0 +1,122 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -7,7 +7,6 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ThunderboltBridgeStatus
from exo.shared.types.topology import (
Connection,
Cycle,
@@ -235,31 +234,3 @@ class Topology:
if not has_tb:
return False
return True
def get_thunderbolt_bridge_cycles(
self,
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
"""
tb_cycles = self.get_cycles_tb()
result: list[list[NodeId]] = []
for cycle in tb_cycles:
node_ids = list(cycle.node_ids)
# Only consider cycles with more than 2 nodes
if len(node_ids) <= 2:
continue
# Check if all nodes in the cycle have TB bridge enabled
all_enabled = all(
node_id in node_tb_bridge_status
and node_tb_bridge_status[node_id].enabled
for node_id in node_ids
)
if all_enabled:
result.append(node_ids)
return result

View File

@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: ModelId
model_id: str
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

View File

@@ -25,14 +25,6 @@ class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int

View File

@@ -71,11 +71,3 @@ class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
class ThunderboltBridgeStatus(CamelCaseModel):
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
enabled: bool
exists: bool
service_name: str | None = None

View File

@@ -13,7 +13,6 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
ThunderboltBridgeStatus,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
@@ -52,10 +51,6 @@ class State(CamelCaseModel):
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:

View File

@@ -1,8 +1,3 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -47,50 +42,3 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -19,7 +19,6 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
@@ -35,142 +34,6 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
IS_DARWIN = sys.platform == "darwin"
async def _get_thunderbolt_devices() -> set[str] | None:
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listallhardwareports"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listallhardwareports failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
output = result.stdout.decode()
thunderbolt_devices: set[str] = set()
current_port: str | None = None
for line in output.splitlines():
line = line.strip()
if line.startswith("Hardware Port:"):
current_port = line.split(":", 1)[1].strip()
elif line.startswith("Device:") and current_port:
device = line.split(":", 1)[1].strip()
if "thunderbolt" in current_port.lower():
thunderbolt_devices.add(device)
current_port = None
return thunderbolt_devices
async def _get_bridge_services() -> dict[str, str] | None:
"""Get mapping of bridge device -> service name from network service order.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listnetworkserviceorder"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listnetworkserviceorder failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
# Parse service order to find bridge devices and their service names
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
service_order_output = result.stdout.decode()
bridge_services: dict[str, str] = {} # device -> service name
current_service: str | None = None
for line in service_order_output.splitlines():
line = line.strip()
# Match "(N) Service Name" or "(*) Service Name" (disabled)
# but NOT "(Hardware Port: ...)" lines
if (
line
and line.startswith("(")
and ")" in line
and not line.startswith("(Hardware Port:")
):
paren_end = line.index(")")
if paren_end + 2 <= len(line):
current_service = line[paren_end + 2 :]
# Match "(Hardware Port: ..., Device: bridgeX)"
elif current_service and "Device: bridge" in line:
# Extract device name from "..., Device: bridge0)"
device_start = line.find("Device: ") + len("Device: ")
device_end = line.find(")", device_start)
if device_end > device_start:
device = line[device_start:device_end]
bridge_services[device] = current_service
return bridge_services
async def _get_bridge_members(bridge_device: str) -> set[str]:
"""Get member interfaces of a bridge device via ifconfig."""
result = await anyio.run_process(
["ifconfig", bridge_device],
check=False,
)
if result.returncode != 0:
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
return set()
members: set[str] = set()
ifconfig_output = result.stdout.decode()
for line in ifconfig_output.splitlines():
line = line.strip()
if line.startswith("member:"):
parts = line.split()
if len(parts) > 1:
members.add(parts[1])
return members
async def _find_thunderbolt_bridge(
bridge_services: dict[str, str], thunderbolt_devices: set[str]
) -> str | None:
"""Find the service name of a bridge containing Thunderbolt interfaces.
Returns the service name if found, None otherwise.
"""
for bridge_device, service_name in bridge_services.items():
members = await _get_bridge_members(bridge_device)
if members & thunderbolt_devices: # intersection is non-empty
return service_name
return None
async def _is_service_enabled(service_name: str) -> bool | None:
"""Check if a network service is enabled.
Returns True if enabled, False if disabled, None on error.
"""
result = await anyio.run_process(
["networksetup", "-getnetworkserviceenabled", service_name],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -getnetworkserviceenabled '{service_name}' "
f"failed with code {result.returncode}: {result.stderr.decode()}"
)
return None
stdout = result.stdout.decode().strip().lower()
return stdout == "enabled"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
@@ -195,64 +58,6 @@ class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class ThunderboltBridgeInfo(TaggedModel):
status: ThunderboltBridgeStatus
@classmethod
async def gather(cls) -> Self | None:
"""Check if a Thunderbolt Bridge network service is enabled on this node.
Detection approach:
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
2. Find bridge devices from network service order (not hardware ports, as
bridges may not appear there)
3. Check each bridge's members via ifconfig
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
5. Check if that network service is enabled
"""
if not IS_DARWIN:
return None
def _no_bridge_status() -> Self:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=False, service_name=None
)
)
try:
tb_devices = await _get_thunderbolt_devices()
if tb_devices is None:
return _no_bridge_status()
bridge_services = await _get_bridge_services()
if not bridge_services:
return _no_bridge_status()
tb_service_name = await _find_thunderbolt_bridge(bridge_services, tb_devices)
if not tb_service_name:
return _no_bridge_status()
enabled = await _is_service_enabled(tb_service_name)
if enabled is None:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=True, service_name=tb_service_name
)
)
return cls(
status=ThunderboltBridgeStatus(
enabled=enabled,
exists=True,
service_name=tb_service_name,
)
)
except Exception as e:
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
return None
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@@ -306,7 +111,6 @@ GatheredInfo = (
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| ThunderboltBridgeInfo
| NodeConfig
| MiscData
| StaticNodeInformation
@@ -321,7 +125,6 @@ class InfoGatherer:
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
@@ -330,7 +133,6 @@ class InfoGatherer:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
@@ -404,17 +206,6 @@ class InfoGatherer:
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return

View File

@@ -17,20 +17,17 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -40,6 +37,53 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -81,12 +125,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def ensure_models_dir() -> Path:
@@ -94,8 +138,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -120,17 +164,19 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
model_id: ModelId, revision: str = "main", recursive: bool = False
repo_id: str, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -138,25 +184,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
return await _fetch_file_list(repo_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -173,7 +219,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -230,10 +276,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
repo_id: str, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -252,7 +298,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(model_id, revision, path, redirected_location)
return await file_meta(repo_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -264,7 +310,7 @@ async def file_meta(
async def download_file_with_retry(
model_id: ModelId,
repo_id: str,
revision: str,
path: str,
target_dir: Path,
@@ -274,23 +320,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
repo_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
model_id: ModelId,
repo_id: str,
revision: str,
path: str,
target_dir: Path,
@@ -299,7 +345,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
length, etag = await file_meta(repo_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -308,7 +354,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -348,7 +394,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
model_id: ModelId,
repo_id: str,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -377,7 +423,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=model_id,
repo_id=repo_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -396,11 +442,11 @@ def calculate_repo_progress(
)
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
model_id, revision, "model.safetensors.index.json", target_dir
repo_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -458,7 +504,7 @@ async def download_shard(
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id, revision, recursive=True
str(shard.model_card.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -492,7 +538,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -509,7 +555,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
shard.model_card.model_id,
str(shard.model_card.model_id),
revision,
file_progress,
all_start_time,
@@ -519,7 +565,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -543,7 +589,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
shard.model_card.model_id,
str(shard.model_card.model_id),
revision,
file.path,
target_dir,
@@ -557,7 +603,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, shard.model_card.model_id, revision, file_progress, all_start_time
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,7 +3,8 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -18,8 +19,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -30,7 +31,7 @@ async def build_base_shard(model_id: ModelId) -> ShardMetadata:
)
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
@@ -147,7 +148,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: ModelId,
model_id: str,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, cast
import mlx.core as mx
import mlx.nn as nn
@@ -67,27 +67,16 @@ def eval_with_timeout(
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
def __init__(self, original_layer: nn.Module):
super().__init__()
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, self["_original_layer"])
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -96,56 +85,57 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = cast(_LayerCallable, self["_original_layer"])
original_layer = object.__getattribute__(self, "_original_layer")
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def patch_pipeline_first_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
rank = group.rank()
class PatchedFirstLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(self, x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
def patch_pipeline_last_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
orig_call_sig = signature(orig_call)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
rank = group.rank()
size = group.size()
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
class PatchedLastLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
"cache", None
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
output: mx.array = orig_call(self, x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(output, (rank + 1) % size, group=group)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
@@ -160,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
layers: list[nn.Module]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
layers = cast(list[nn.Module], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers = cast(list[nn.Module], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -191,15 +181,12 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[-1],
device_rank,
world_size,
group=group,
group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -334,7 +321,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
@@ -383,6 +370,7 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -445,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -520,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -564,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
@@ -598,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -611,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -673,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
y = self.original_layer(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore

View File

@@ -23,7 +23,6 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -297,7 +296,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -321,9 +320,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
return None
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -11,15 +11,14 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
class MockLayer(nn.Module):
@@ -117,12 +116,11 @@ def run_gpt_oss_pipeline_device(
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:
@@ -184,11 +182,11 @@ def run_gpt_oss_tensor_parallel_device(
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:

View File

@@ -10,8 +10,8 @@ import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
patch_pipeline_first_layer,
patch_pipeline_last_layer,
patch_pipeline_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
group = mx.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: ModelId) -> Path:
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: ModelId) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[ModelCard]:
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = card
families[family] = (card.model_id.short(), card)
return list(families.values())
TEST_MODELS: list[ModelCard] = get_test_models()
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@@ -101,13 +101,14 @@ def event_loop():
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = model_card.model_id
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -166,15 +167,16 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = model_card.model_id
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -207,18 +209,19 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = model_card.model_id
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -298,14 +301,16 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
model_card = kimi_models[0]
model_id = model_card.model_id
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -344,15 +349,17 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_model_cards:
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
model_card = glm_model_cards[0]
model_id = model_card.model_id
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

12
uv.lock generated
View File

@@ -248,7 +248,6 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -282,7 +281,6 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -317,16 +315,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"