Compare commits

...

10 Commits

Author SHA1 Message Date
Alex Cheema
55b67e2be2 Merge branch 'main' into releases/v1.0.64 2026-01-23 01:40:20 +00:00
Alex Cheema
df240f834d Fix GLM and Kimi tool calling crashes (#1255)
## Motivation

Fixes tool calling crashes with GLM-4.7-Flash and Kimi-K2 models.

Related: #1254

Two distinct issues were causing crashes:
1. **Tool parser crashes** - The upstream GLM47 and Kimi tool parsers
call `.group()` on regex matches without checking for `None`, causing
`AttributeError` when the model outputs malformed tool calls
2. **Chat template crashes** - GLM's chat template expects
`tool_calls[].function.arguments` to be a dict, but OpenAI format
provides it as a JSON string, causing `'str object' has no attribute
'items'`

## Changes

**`src/exo/worker/runner/runner.py`:**
- Add `patch_glm_tokenizer()` - fixed version of mlx_lm's glm47 parser
with None checks
- Fix `patch_kimi_tokenizer()` - add None checks before calling
`.group()` on regex matches
- Add `ValueError` and `AttributeError` to exception handling in
`parse_tool_calls()`

**`src/exo/worker/engines/mlx/utils_mlx.py`:**
- Add `_normalize_tool_calls()` - parses
`tool_calls[].function.arguments` from JSON string to dict for templates
that expect dicts (like GLM-4.7-Flash)

## Why It Works

1. **Parser fixes**: By checking if regex matches are `None` before
calling `.group()`, we can raise a proper `ValueError` instead of
crashing with `AttributeError`

2. **Template fix**: The GLM-4.7-Flash chat template iterates over
arguments with `.items()`:
   ```jinja2
   {% set _args = tc.arguments %}{% for k, v in _args.items() %}
   ```
The OpenAI format provides `arguments` as a JSON string;
`_normalize_tool_calls()` parses it into a dict before the messages are
passed to the template.
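
Below is a minimal sketch of that normalization step, assuming it operates on an
OpenAI-style message list; the actual signature in `utils_mlx.py` may differ:

```python
import json
from typing import Any


def _normalize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Parse tool_calls[].function.arguments from a JSON string into a dict so
    templates that call .items() on it (like GLM-4.7-Flash's) do not crash."""
    for message in messages:
        for tool_call in message.get("tool_calls") or []:
            function = tool_call.get("function", {})
            arguments = function.get("arguments")
            if isinstance(arguments, str):
                try:
                    function["arguments"] = json.loads(arguments)
                except json.JSONDecodeError:
                    # Leave malformed arguments untouched; the parser reports the error.
                    pass
    return messages
```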

## Test Plan

### Manual Testing
- Hardware: Mac with GLM-4.7-Flash-4bit model
- Tested tool calling with GLM model - no longer crashes

### Automated Testing
- Existing tests pass (`uv run pytest`)
- Type checking passes (`uv run basedpyright`)
- Linting passes (`uv run ruff check`)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-23 01:39:59 +00:00
ciaranbor
cd125b3b8c Use icon for image editing models (#1252)
## Motivation

Visual indicator for image editing models

## Changes

Add a pencil icon to image editing models in the model list
2026-01-22 22:37:34 +00:00
Ryuichi Leo Takashige
30cfad9b68 Use custom fork 2026-01-22 22:19:35 +00:00
Alex Cheema
b783a21399 dashboard: add placement filter by clicking topology nodes (#1248)
## Motivation

When selecting a model for placement, users often want to see placements
that utilize specific nodes in their cluster. Currently there's no way
to filter the placement previews to focus on configurations that include
particular machines.

## Changes

- **Backend**: Added `node_ids` query parameter to the
`/placement-previews` API endpoint. When provided, the endpoint filters
the topology to only include the specified nodes before generating
placements using the new `Topology.filter_to_nodes()` method.

- **Topology class**: Added a `filter_to_nodes(node_ids)` method that
creates a new topology containing only the specified nodes and the edges
between them (a minimal sketch follows this list).

- **App store**: Added `previewNodeFilter` state to track selected
nodes, with methods to toggle/clear the filter. Automatically cleans up
filter when nodes are removed from the cluster and re-fetches previews
when topology changes.

- **TopologyGraph component**: Added click handlers to toggle node
filter selection, hover effects to indicate clickable nodes, and visual
styling (yellow highlight for selected, dimmed for filtered-out nodes).

- **Main page**: Added filter indicator in top-right corner of topology
showing active filter count with a clear button.
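
A minimal sketch of the `filter_to_nodes` helper referenced above, under the
assumption that a topology is a set of node IDs plus an adjacency map (the real
class in exo may be structured differently):

```python
from dataclasses import dataclass, field


@dataclass
class Topology:
    nodes: set[str] = field(default_factory=set)
    # adjacency map: source node id -> sink node ids
    edges: dict[str, set[str]] = field(default_factory=dict)

    def filter_to_nodes(self, node_ids: set[str]) -> "Topology":
        """Return a new topology with only the given nodes and the edges
        whose endpoints are both in the selection."""
        kept_nodes = self.nodes & node_ids
        kept_edges = {
            src: {dst for dst in dsts if dst in kept_nodes}
            for src, dsts in self.edges.items()
            if src in kept_nodes
        }
        return Topology(nodes=kept_nodes, edges=kept_edges)
```

When `node_ids` is supplied, the `/placement-previews` endpoint would apply
something like this to the current topology before generating placements.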

## Why It Works

The filtering happens at the backend/placement generation level rather
than just filtering the results. This ensures we see all valid placement
combinations for the selected nodes, not just a subset that happened to
be generated for the full topology.

The visual feedback uses the same rendering approach as the existing
highlight system - state is tracked in Svelte and applied during render,
so it persists across data updates without flickering.

## Test Plan

### Manual Testing
- Click a node in topology → should show yellow highlight and filter
indicator
- Click another node → indicator shows "2 nodes", previews update to
show only placements using both
- Hover over nodes → subtle yellow highlight indicates they're clickable
- Click X on filter indicator → clears filter, shows all placements
again
- Disconnect a node while it's in filter → filter auto-removes that node

### Automated Testing
- Existing tests cover the Topology class; the new `filter_to_nodes`
method follows the same patterns

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 22:12:57 +00:00
Alex Cheema
43f12f5d08 Replace LaunchDaemon with dynamic Thunderbolt Bridge loop detection (#1222)
## Motivation

The previous approach installed a LaunchDaemon plist that ran
periodically to disable Thunderbolt Bridge. This required full admin
privileges upfront and ran regardless of whether a problematic loop
existed.

This change replaces that with dynamic detection - only prompting the
user when an actual TB bridge loop with 3+ machines is detected, and
using fine-grained SCPreferences authorization instead of full admin.

## Changes

**Backend (Python):**
- Added `ThunderboltBridgeStatus` model to track bridge enabled/exists
state per node
- Added `node_thunderbolt_bridge` and `thunderbolt_bridge_cycles` fields
to State
- Added `get_thunderbolt_bridge_cycles()` method to Topology class
- **Robust TB bridge detection:**
  - Finds bridge network services from `-listnetworkserviceorder` (not
    `-listallhardwareports`, which can miss bridges)
  - Checks each bridge's member interfaces via `ifconfig` to verify it
    contains Thunderbolt interfaces
  - Handles varying service names (e.g., "TB Bridge", "Thunderbolt
    Bridge", "Bridge (bridge0)")
  - Includes `service_name` in status for correct disable commands
  - Added warning logs for all error cases in detection
- Updated `apply.py` to handle the new event type and recompute cycles
on node timeout

**Swift App:**
- New `ThunderboltBridgeService` that monitors for cycles from cluster
state
- Shows NSAlert when a cycle with >2 machines is detected
- Uses `SCPreferencesCreateWithAuthorization` with
`system.services.systemconfiguration.network` right for targeted
permissions
- **Auto-cleanup of legacy LaunchDaemon:** On app startup, checks for
and removes old plist/scripts (non-fatal if user cancels)
- **Periodic local network checking:** Re-checks every 10s so the
warning disappears when the user grants permission
- **Fixed ClusterState model:** Updated to decode new granular state
fields (`nodeIdentities`, `nodeMemory`, `nodeSystem`,
`nodeThunderboltBridge`) with computed `nodeProfiles` property for
backwards compatibility
- **Fixed Topology model:** Updated to match actual JSON structure where
`nodes` is an array of strings (not objects) and `connections` is a
nested map (not flat array)
- Cleaned up `NetworkSetupHelper` by removing daemon installation code
(now only handles uninstall)

**Dashboard:**
- Added yellow warning badge on topology when TB bridge cycle detected
- On hover: highlights affected nodes in yellow on the topology graph
- Shows which machines are in the cycle with friendly names
- Provides copy-paste terminal command with the correct service name:
  ```
  sudo networksetup -setnetworkserviceenabled "<service-name>" off
  ```
- Warning appears in all topology views (full, welcome, and minimized
chat sidebar)
- **Debug mode:** Shows "TB:ON" or "TB:OFF" status next to each node in
the topology

## Why It Works

- Cycle detection happens on the backend, where we have full topology
information (a minimal sketch follows this list)
- Only cycles with 3+ machines are flagged (2-node connections are fine)
- TB bridge detection is robust:
  - Uses `-listnetworkserviceorder` to find bridges (works on all machines
    tested)
  - Verifies bridge membership via `ifconfig` to confirm Thunderbolt
    interfaces
  - Handles different service names across machines
- The Swift app reacts to detected cycles and prompts the user once per
cycle
- The dashboard provides visual feedback and actionable instructions
- `SCPreferencesCreateWithAuthorization` provides the minimal
permissions needed to modify network service state
- Legacy LaunchDaemon is automatically cleaned up on first launch with
this version
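
As referenced above, a minimal sketch of the cycle check, assuming the backend
sees the topology as an undirected edge list plus per-node bridge status (the
actual `Topology.get_thunderbolt_bridge_cycles()` may differ):

```python
from collections import defaultdict


def thunderbolt_bridge_cycles(
    edges: list[tuple[str, str]],
    bridge_enabled: dict[str, bool],
) -> list[list[str]]:
    """Return groups of 3+ bridge-enabled nodes whose Thunderbolt links form a loop."""
    # Keep only edges between nodes that both have the TB bridge enabled.
    adj: dict[str, set[str]] = defaultdict(set)
    for a, b in edges:
        if bridge_enabled.get(a) and bridge_enabled.get(b):
            adj[a].add(b)
            adj[b].add(a)

    cycles: list[list[str]] = []
    seen: set[str] = set()
    for start in adj:
        if start in seen:
            continue
        # Walk the connected component, counting its nodes and edge endpoints.
        component: list[str] = []
        stack = [start]
        seen.add(start)
        edge_ends = 0
        while stack:
            node = stack.pop()
            component.append(node)
            edge_ends += len(adj[node])
            for neighbour in adj[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    stack.append(neighbour)
        # A connected component contains a cycle iff it has at least as many
        # edges as nodes (each undirected edge is counted twice in edge_ends).
        if len(component) >= 3 and edge_ends // 2 >= len(component):
            cycles.append(sorted(component))
    return cycles
```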

## Test Plan

### Manual Testing
Here EXO detected a TB bridge cycle:

#### Dashboard:
<img width="1363" height="884" alt="Screenshot 2026-01-21 at 10 07
30 PM"
src="https://github.com/user-attachments/assets/7da9c621-0c91-42c4-898e-4952188a1f61"
/>

#### Hovering the warning:
<img width="359" height="279" alt="Screenshot 2026-01-21 at 16 30 57"
src="https://github.com/user-attachments/assets/05501dcf-3d4a-4704-9f38-257748c05a53"
/>

#### macOS app warning popup:
<img width="270" height="410" alt="Screenshot 2026-01-21 at 16 29 08"
src="https://github.com/user-attachments/assets/45714427-08c3-4fb4-9e61-144925c51adf"
/>

#### Which then asks for the user's password:
<img width="263" height="372" alt="Screenshot 2026-01-21 at 16 29 28"
src="https://github.com/user-attachments/assets/7502e591-596d-4128-8cf5-6a12674e27bc"
/>

When entered, this successfully disables the bridge and the warning no longer
appears on the dashboard.

#### When it fails it shows the error message:
<img width="263" height="234" alt="Screenshot 2026-01-21 at 14 45 38"
src="https://github.com/user-attachments/assets/2d10b3d5-69d7-46ea-b631-d52d8651ab41"
/>

### Automated Testing
- Type checker: 0 errors (`uv run basedpyright`)
- Linter: All checks passed (`uv run ruff check`)
- Tests: 118 passed (`uv run pytest`)
- Dashboard: Builds successfully (`npm run build`)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 21:53:05 +00:00
ciaranbor
8027d7933f Ciaran/hf token (#1250)
## Motivation

black-forest-labs models require Hugging Face authentication and signup to
download. We don't handle this gracefully.
https://github.com/exo-explore/exo/issues/1242

## Changes

- Handle auth errors
- Surface the error to the UI and suggest a resolution
- Support using the `HF_TOKEN` env variable for auth (see the sketch after
this list)
- Hide image functionality behind `EXO_ENABLE_IMAGE_MODELS=true` for now
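
A hypothetical sketch of the auth handling, assuming `huggingface_hub`'s
`snapshot_download` and `GatedRepoError`; exo's actual download path and error
surfacing differ in detail:

```python
import os

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError


def download_model(repo_id: str) -> str:
    try:
        return snapshot_download(repo_id, token=os.environ.get("HF_TOKEN"))
    except GatedRepoError as err:
        # Surface an actionable message to the UI instead of a raw traceback.
        raise RuntimeError(
            f"{repo_id} is gated on Hugging Face. Run `hf auth login` or set "
            "HF_TOKEN, then retry the download."
        ) from err
```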

## Why It Works

Users are presented with actionable feedback when the issue occurs

## Test Plan

### Manual Testing

Confirmed that loading a black-forest-labs model presents the error in the UI.
Confirmed that both `hf auth login` and setting `HF_TOKEN` resolve the issue.
2026-01-22 20:39:53 +00:00
Evan
ac6efa747b add kimi tool parsing
this patches the kimi tokenizer to add tool calling - it can be reverted
once upstream support is added for kimi-k2
2026-01-22 11:49:25 +00:00
Evan
2e3c33db6d implement mlx-lm tool calling
splits the runner's generation chunks into tool calls, tokens, and
errors, and writes tool call chunks when the upstream parser detects
them.
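
A rough sketch of that chunk split, with assumed type names (not exo's actual
classes):

```python
from dataclasses import dataclass, field


@dataclass
class TokenChunk:
    text: str


@dataclass
class ToolCallChunk:
    name: str
    arguments: dict = field(default_factory=dict)


@dataclass
class ErrorChunk:
    message: str


# Each generation step yields exactly one of these; tool call chunks are written
# whenever the upstream parser detects a complete tool call in the output.
GenerationChunk = TokenChunk | ToolCallChunk | ErrorChunk
```
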
2026-01-22 11:49:25 +00:00
rltakashige
fc8e6ad06b Reduce download log spam (#1249)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-22 11:28:36 +00:00
44 changed files with 2545 additions and 793 deletions

View File

@@ -14,6 +14,7 @@ struct ContentView: View {
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@@ -24,6 +25,8 @@ struct ContentView: View {
@State private var bugReportMessage: String?
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""
@State private var pendingHFToken: String = ""
@State private var pendingEnableImageModels = false
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -303,6 +306,49 @@ struct ContentView: View {
.disabled(pendingNamespace == controller.customNamespace)
}
}
VStack(alignment: .leading, spacing: 4) {
Text("HuggingFace Token")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
SecureField("optional", text: $pendingHFToken)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingHFToken = controller.hfToken
}
Button("Save & Restart") {
controller.hfToken = pendingHFToken
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingHFToken == controller.hfToken)
}
}
Divider()
HStack {
Toggle(
"Enable Image Models (experimental)", isOn: $pendingEnableImageModels
)
.toggleStyle(.switch)
.font(.caption2)
.onAppear {
pendingEnableImageModels = controller.enableImageModels
}
Spacer()
Button("Save & Restart") {
controller.enableImageModels = pendingEnableImageModels
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingEnableImageModels == controller.enableImageModels)
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
@@ -423,6 +469,44 @@ struct ContentView: View {
}
}
/// Shows TB bridge status for all nodes from exo cluster state
private var clusterThunderboltBridgeView: some View {
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
let localNodeId = stateService.localNodeId
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
return VStack(alignment: .leading, spacing: 1) {
if bridgeStatuses.isEmpty {
Text("Cluster TB Bridge: No data")
.font(.caption2)
.foregroundColor(.secondary)
} else {
Text("Cluster TB Bridge Status:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
if let status = bridgeStatuses[nodeId] {
let nodeName =
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
let isLocal = nodeId == localNodeId
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
let statusText =
!status.exists
? "N/A"
: (status.enabled ? "Enabled" : "Disabled")
let color: Color =
!status.exists
? .secondary
: (status.enabled ? .red : .green)
Text("\(prefix) \(statusText)")
.font(.caption2)
.foregroundColor(color)
}
}
}
}
}
private var interfaceIpList: some View {
let statuses = networkStatusService.status.interfaceStatuses
return VStack(alignment: .leading, spacing: 1) {
@@ -465,6 +549,7 @@ struct ContentView: View {
Text(thunderboltStatusText)
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
clusterThunderboltBridgeView
interfaceIpList
rdmaStatusView
sendBugReportButton

View File

@@ -21,6 +21,7 @@ struct EXOApp: App {
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -41,10 +42,13 @@ struct EXOApp: App {
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -58,6 +62,7 @@ struct EXOApp: App {
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
.environmentObject(thunderboltBridgeService)
} label: {
menuBarIcon
}
@@ -130,6 +135,37 @@ struct EXOApp: App {
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
// Dispatch async to ensure app is ready before showing alert
DispatchQueue.main.async {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to configure local network discovery on your device. This requires granting permission once."
alert.alertStyle = .informational
alert.addButton(withTitle: "Continue")
alert.addButton(withTitle: "Later")
let response = alert.runModal()
guard response == .alertFirstButtonReturn else {
Logger().info("User deferred legacy network setup cleanup")
return
}
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
}
/// Helper for managing EXO's launch-at-login registration

View File

@@ -3,6 +3,8 @@ import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
private let hfTokenKey = "EXOHFToken"
private let enableImageModelsKey = "EXOEnableImageModels"
@MainActor
final class ExoProcessController: ObservableObject {
@@ -37,6 +39,22 @@ final class ExoProcessController: ObservableObject {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
@Published var hfToken: String = {
return UserDefaults.standard.string(forKey: hfTokenKey) ?? ""
}()
{
didSet {
UserDefaults.standard.set(hfToken, forKey: hfTokenKey)
}
}
@Published var enableImageModels: Bool = {
return UserDefaults.standard.bool(forKey: enableImageModelsKey)
}()
{
didSet {
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -191,6 +209,12 @@ final class ExoProcessController: ObservableObject {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
if !hfToken.isEmpty {
environment["HF_TOKEN"] = hfToken
}
if enableImageModels {
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
}
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {

View File

@@ -5,17 +5,43 @@ import Foundation
struct ClusterState: Decodable {
let instances: [String: ClusterInstance]
let runners: [String: RunnerStatusSummary]
let nodeProfiles: [String: NodeProfile]
let tasks: [String: ClusterTask]
let topology: Topology?
let downloads: [String: [NodeDownloadStatus]]
let thunderboltBridgeCycles: [[String]]
// Granular node state (split from the old nodeProfiles)
let nodeIdentities: [String: NodeIdentity]
let nodeMemory: [String: MemoryInfo]
let nodeSystem: [String: SystemInfo]
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
/// Computed property for backwards compatibility - merges granular state into NodeProfile
var nodeProfiles: [String: NodeProfile] {
var profiles: [String: NodeProfile] = [:]
let allNodeIds = Set(nodeIdentities.keys)
.union(nodeMemory.keys)
.union(nodeSystem.keys)
for nodeId in allNodeIds {
let identity = nodeIdentities[nodeId]
let memory = nodeMemory[nodeId]
let system = nodeSystem[nodeId]
profiles[nodeId] = NodeProfile(
modelId: identity?.modelId,
chipId: identity?.chipId,
friendlyName: identity?.friendlyName,
memory: memory,
system: system
)
}
return profiles
}
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
self.thunderboltBridgeCycles =
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
// Granular node state
self.nodeIdentities =
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
?? [:]
self.nodeMemory =
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
self.nodeSystem =
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
self.nodeThunderboltBridge =
try container.decodeIfPresent(
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
) ?? [:]
}
private enum CodingKeys: String, CodingKey {
case instances
case runners
case nodeProfiles
case topology
case tasks
case downloads
case thunderboltBridgeCycles
case nodeIdentities
case nodeMemory
case nodeSystem
case nodeThunderboltBridge
}
}
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
let system: SystemInfo?
}
struct NodeIdentity: Decodable {
let modelId: String?
let chipId: String?
let friendlyName: String?
}
struct ThunderboltBridgeStatus: Decodable {
let enabled: Bool
let exists: Bool
let serviceName: String?
}
struct MemoryInfo: Decodable {
let ramTotal: MemoryValue?
let ramAvailable: MemoryValue?
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
}
struct Topology: Decodable {
let nodes: [TopologyNode]
let connections: [TopologyConnection]?
/// Node IDs in the topology
let nodes: [String]
/// Flattened list of connections (source -> sink pairs)
let connections: [TopologyConnection]
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
// Connections come as nested map: { source: { sink: [edges] } }
// We flatten to array of (source, sink) pairs
var flatConnections: [TopologyConnection] = []
if let nested = try container.decodeIfPresent(
[String: [String: [AnyCodable]]].self, forKey: .connections
) {
for (source, sinks) in nested {
for sink in sinks.keys {
flatConnections.append(
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
}
}
}
self.connections = flatConnections
}
private enum CodingKeys: String, CodingKey {
case nodes
case connections
}
}
struct TopologyNode: Decodable {
let nodeId: String
let nodeProfile: NodeProfile
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
private struct AnyCodable: Decodable {
init(from decoder: Decoder) throws {
// Just consume the value without storing it
_ = try? decoder.singleValueContainer().decode(Bool.self)
_ = try? decoder.singleValueContainer().decode(Int.self)
_ = try? decoder.singleValueContainer().decode(Double.self)
_ = try? decoder.singleValueContainer().decode(String.self)
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
}
}
struct TopologyConnection: Decodable {
struct TopologyConnection {
let localNodeId: String
let sendBackNodeId: String
}

View File

@@ -55,12 +55,16 @@ struct BugReportService {
let stateData = try await stateResult
let eventsData = try await eventsResult
// Extract cluster TB bridge status from exo state
let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)
let reportJSON = makeReportJson(
timestamp: timestamp,
hostName: hostName,
ifconfig: ifconfigText,
debugInfo: debugInfo,
isManual: isManual
isManual: isManual,
clusterTbBridgeStatus: clusterTbBridgeStatus
)
let uploads: [(path: String, data: Data?)] = [
@@ -178,18 +182,19 @@ struct BugReportService {
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
return false
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
// No bridge containing Thunderbolt interfaces exists
return nil
}
if output.contains("disabled") {
return true
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
else {
return nil
}
return nil
// Return true if disabled, false if enabled
return !isEnabled
}
private func readInterfaces() -> [DebugInfo.InterfaceStatus] {
@@ -268,11 +273,12 @@ struct BugReportService {
hostName: String,
ifconfig: String,
debugInfo: DebugInfo,
isManual: Bool
isManual: Bool,
clusterTbBridgeStatus: [[String: Any]]?
) -> Data? {
let system = readSystemMetadata()
let exo = readExoMetadata()
let payload: [String: Any] = [
var payload: [String: Any] = [
"timestamp": timestamp,
"host": hostName,
"ifconfig": ifconfig,
@@ -282,9 +288,38 @@ struct BugReportService {
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated",
]
if let tbStatus = clusterTbBridgeStatus {
payload["cluster_thunderbolt_bridge"] = tbStatus
}
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
/// Extracts cluster-wide Thunderbolt Bridge status from exo state JSON
private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {
guard let data = stateData,
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
let nodeThunderboltBridge = json["node_thunderbolt_bridge"] as? [String: [String: Any]]
else {
return nil
}
var result: [[String: Any]] = []
for (nodeId, status) in nodeThunderboltBridge {
var entry: [String: Any] = ["node_id": nodeId]
if let enabled = status["enabled"] as? Bool {
entry["enabled"] = enabled
}
if let exists = status["exists"] as? Bool {
entry["exists"] = exists
}
if let serviceName = status["service_name"] as? String {
entry["service_name"] = serviceName
}
result.append(entry)
}
return result.isEmpty ? nil : result
}
private func readSystemMetadata() -> [String: Any] {
let hostname = safeRunCommand(["/bin/hostname"])
let computerName = safeRunCommand(["/usr/sbin/scutil", "--get", "ComputerName"])

View File

@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
private var periodicTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
/// Checks if local network access is working (one-time check).
func check() {
performCheck()
}
/// Starts periodic checking of local network access.
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
func startPeriodicChecking(interval: TimeInterval = 10) {
stopPeriodicChecking()
// Do an immediate check first
performCheck()
// Then schedule periodic checks
periodicTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
guard !Task.isCancelled else { break }
self?.performCheck()
}
}
}
/// Stops periodic checking.
func stopPeriodicChecking() {
periodicTask?.cancel()
periodicTask = nil
}
private func performCheck() {
checkTask?.cancel()
status = .checking
// Only show "checking" status on first check to avoid UI flicker
if status == .unknown {
status = .checking
}
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
// Only log on state changes or first check to reduce noise
if isFirstCheck || result != self.status {
Self.logger.info("Local network check: \(result.displayText)")
}
}
}
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
}
func stop() {
stopPeriodicChecking()
checkTask?.cancel()
checkTask = nil
connection?.cancel()

View File

@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
// Legacy script path from older versions
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
}
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
return scriptExists || legacyScriptExists || plistExists
}
private static func makeUninstallScript() -> String {
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
# Remove the script (current and legacy paths) and parent directory if empty
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
@@ -98,99 +63,42 @@ enum NetworkSetupHelper {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable Thunderbolt Bridge if it exists
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
# We find it dynamically by looking for bridges containing Thunderbolt interfaces
find_and_enable_thunderbolt_bridge() {
# Get Thunderbolt interface devices from hardware ports
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
if echo "$members" | grep -qx "$tb_dev"; then
# Find the service name for this bridge device
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
fi
fi
done
done
}
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() async throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

View File

@@ -153,22 +153,18 @@ private struct NetworkStatusFetcher {
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
let lower = result.output.lowercased() + result.error.lowercased()
if lower.contains("not a recognized network service") {
return .deleted
}
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
// No bridge containing Thunderbolt interfaces exists
return .deleted
}
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
else {
return nil
}
let output = result.output.lowercased()
if output.contains("enabled") {
return .enabled
}
if output.contains("disabled") {
return .disabled
}
return nil
return isEnabled ? .enabled : .disabled
}
private func readBridgeInactive() -> Bool? {

View File

@@ -0,0 +1,194 @@
import Foundation
import os.log
/// Utility for dynamically detecting Thunderbolt Bridge network services.
/// This mirrors the Python logic in info_gatherer.py - we never assume the service
/// is named "Thunderbolt Bridge", instead we find bridges containing Thunderbolt interfaces.
enum ThunderboltBridgeDetector {
private static let logger = Logger(
subsystem: "io.exo.EXO", category: "ThunderboltBridgeDetector")
struct CommandResult {
let exitCode: Int32
let output: String
let error: String
}
/// Find the network service name of a bridge containing Thunderbolt interfaces.
/// Returns nil if no such bridge exists.
static func findThunderboltBridgeServiceName() -> String? {
// 1. Get all Thunderbolt interface devices (e.g., en2, en3)
guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {
logger.debug("No Thunderbolt devices found")
return nil
}
logger.debug("Found Thunderbolt devices: \(thunderboltDevices.joined(separator: ", "))")
// 2. Get bridge services from network service order
guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {
logger.debug("No bridge services found")
return nil
}
logger.debug("Found bridge services: \(bridgeServices.keys.joined(separator: ", "))")
// 3. Find a bridge that contains Thunderbolt interfaces
for (bridgeDevice, serviceName) in bridgeServices {
let members = getBridgeMembers(bridgeDevice: bridgeDevice)
logger.debug(
"Bridge \(bridgeDevice) (\(serviceName)) has members: \(members.joined(separator: ", "))"
)
// Check if any Thunderbolt device is a member of this bridge
if !members.isDisjoint(with: thunderboltDevices) {
logger.info(
"Found Thunderbolt Bridge service: '\(serviceName)' (device: \(bridgeDevice))")
return serviceName
}
}
logger.debug("No bridge found containing Thunderbolt interfaces")
return nil
}
/// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
private static func getThunderboltDevices() -> Set<String>? {
let result = runCommand(["networksetup", "-listallhardwareports"])
guard result.exitCode == 0 else {
logger.warning("networksetup -listallhardwareports failed: \(result.error)")
return nil
}
var thunderboltDevices: Set<String> = []
var currentPort: String?
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("Hardware Port:") {
currentPort = String(trimmed.dropFirst("Hardware Port:".count)).trimmingCharacters(
in: .whitespaces)
} else if trimmed.hasPrefix("Device:"), let port = currentPort {
let device = String(trimmed.dropFirst("Device:".count)).trimmingCharacters(
in: .whitespaces)
if port.lowercased().contains("thunderbolt") {
thunderboltDevices.insert(device)
}
currentPort = nil
}
}
return thunderboltDevices
}
/// Get mapping of bridge device -> service name from network service order.
private static func getBridgeServices() -> [String: String]? {
let result = runCommand(["networksetup", "-listnetworkserviceorder"])
guard result.exitCode == 0 else {
logger.warning("networksetup -listnetworkserviceorder failed: \(result.error)")
return nil
}
// Parse service order to find bridge devices and their service names
// Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
var bridgeServices: [String: String] = [:]
var currentService: String?
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
// Match "(N) Service Name" or "(*) Service Name" (disabled)
// but NOT "(Hardware Port: ...)" lines
if trimmed.hasPrefix("("), trimmed.contains(")"),
!trimmed.hasPrefix("(Hardware Port:")
{
if let parenEnd = trimmed.firstIndex(of: ")") {
let afterParen = trimmed.index(after: parenEnd)
if afterParen < trimmed.endIndex {
currentService =
String(trimmed[afterParen...])
.trimmingCharacters(in: .whitespaces)
}
}
}
// Match "(Hardware Port: ..., Device: bridgeX)"
else if let service = currentService, trimmed.contains("Device: bridge") {
// Extract device name from "..., Device: bridge0)"
if let deviceRange = trimmed.range(of: "Device: ") {
let afterDevice = trimmed[deviceRange.upperBound...]
if let parenIndex = afterDevice.firstIndex(of: ")") {
let device = String(afterDevice[..<parenIndex])
bridgeServices[device] = service
}
}
}
}
return bridgeServices
}
/// Get member interfaces of a bridge device via ifconfig.
private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {
let result = runCommand(["ifconfig", bridgeDevice])
guard result.exitCode == 0 else {
logger.debug("ifconfig \(bridgeDevice) failed")
return []
}
var members: Set<String> = []
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("member:") {
let parts = trimmed.split(separator: " ")
if parts.count > 1 {
members.insert(String(parts[1]))
}
}
}
return members
}
/// Check if a network service is enabled.
static func isServiceEnabled(serviceName: String) -> Bool? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", serviceName])
guard result.exitCode == 0 else {
logger.warning("Failed to check if '\(serviceName)' is enabled: \(result.error)")
return nil
}
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private static func runCommand(_ arguments: [String]) -> CommandResult {
let process = Process()
process.launchPath = "/usr/bin/env"
process.arguments = arguments
let stdout = Pipe()
let stderr = Pipe()
process.standardOutput = stdout
process.standardError = stderr
do {
try process.run()
} catch {
return CommandResult(exitCode: -1, output: "", error: error.localizedDescription)
}
process.waitUntilExit()
let outputData = stdout.fileHandleForReading.readDataToEndOfFile()
let errorData = stderr.fileHandleForReading.readDataToEndOfFile()
return CommandResult(
exitCode: process.terminationStatus,
output: String(decoding: outputData, as: UTF8.self),
error: String(decoding: errorData, as: UTF8.self)
)
}
}

View File

@@ -0,0 +1,258 @@
import AppKit
import Combine
import Foundation
import Security
import SystemConfiguration
import os.log
@MainActor
final class ThunderboltBridgeService: ObservableObject {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
@Published private(set) var detectedCycle: [String]?
@Published private(set) var hasPromptedForCurrentCycle = false
@Published private(set) var lastError: String?
private weak var clusterStateService: ClusterStateService?
private var cancellables = Set<AnyCancellable>()
private var previousCycleSignature: String?
init(clusterStateService: ClusterStateService) {
self.clusterStateService = clusterStateService
setupObserver()
}
private func setupObserver() {
guard let service = clusterStateService else { return }
service.$latestSnapshot
.compactMap { $0 }
.sink { [weak self] snapshot in
self?.checkForCycles(snapshot: snapshot)
}
.store(in: &cancellables)
}
private func checkForCycles(snapshot: ClusterState) {
let cycles = snapshot.thunderboltBridgeCycles
// Only consider cycles with more than 2 nodes
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
// No problematic cycles detected, reset state
if detectedCycle != nil {
detectedCycle = nil
previousCycleSignature = nil
hasPromptedForCurrentCycle = false
}
return
}
// Create a signature for this cycle to detect if it changed
let cycleSignature = firstCycle.sorted().joined(separator: ",")
// If this is a new/different cycle, reset the prompt state
if cycleSignature != previousCycleSignature {
previousCycleSignature = cycleSignature
hasPromptedForCurrentCycle = false
}
detectedCycle = firstCycle
// Only prompt once per cycle
if !hasPromptedForCurrentCycle {
showDisableBridgePrompt(nodeIds: firstCycle)
}
}
private func showDisableBridgePrompt(nodeIds: [String]) {
hasPromptedForCurrentCycle = true
// Get friendly names for the nodes if available
let nodeNames = nodeIds.map { nodeId -> String in
if let snapshot = clusterStateService?.latestSnapshot,
let profile = snapshot.nodeProfiles[nodeId],
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
{
return friendlyName
}
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
}
let machineNames = nodeNames.joined(separator: ", ")
let alert = NSAlert()
alert.messageText = "Thunderbolt Bridge Loop Detected"
alert.informativeText = """
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Disable Bridge")
alert.addButton(withTitle: "Not Now")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
Task {
await disableThunderboltBridge()
}
}
}
func disableThunderboltBridge() async {
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
lastError = nil
do {
try await disableThunderboltBridgeWithSCPreferences()
Self.logger.info("Successfully disabled Thunderbolt Bridge")
} catch {
Self.logger.error(
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
)
lastError = error.localizedDescription
showErrorAlert(message: error.localizedDescription)
}
}
private func disableThunderboltBridgeWithSCPreferences() async throws {
// 1. Create authorization reference
var authRef: AuthorizationRef?
var status = AuthorizationCreate(nil, nil, [], &authRef)
guard status == errAuthorizationSuccess, let authRef = authRef else {
throw ThunderboltBridgeError.authorizationFailed
}
defer { AuthorizationFree(authRef, [.destroyRights]) }
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled
}
throw ThunderboltBridgeError.authorizationDenied
}
// 3. Create SCPreferences with authorization
guard
let prefs = SCPreferencesCreateWithAuthorization(
kCFAllocatorDefault,
"EXO" as CFString,
nil,
authRef
)
else {
throw ThunderboltBridgeError.preferencesCreationFailed
}
// 4. Lock, modify, commit
guard SCPreferencesLock(prefs, true) else {
throw ThunderboltBridgeError.lockFailed
}
defer {
SCPreferencesUnlock(prefs)
}
// 5. Find the Thunderbolt Bridge service dynamically (don't assume the name)
guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()
else {
throw ThunderboltBridgeError.serviceNotFound
}
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
throw ThunderboltBridgeError.servicesNotFound
}
var found = false
for service in allServices {
if let name = SCNetworkServiceGetName(service) as String?,
name == targetServiceName
{
guard SCNetworkServiceSetEnabled(service, false) else {
throw ThunderboltBridgeError.disableFailed
}
found = true
Self.logger.info(
"Found and disabled Thunderbolt Bridge service: '\(targetServiceName)'")
break
}
}
if !found {
throw ThunderboltBridgeError.serviceNotFound
}
// 6. Commit and apply
guard SCPreferencesCommitChanges(prefs) else {
throw ThunderboltBridgeError.commitFailed
}
guard SCPreferencesApplyChanges(prefs) else {
throw ThunderboltBridgeError.applyFailed
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Failed to Disable Thunderbolt Bridge"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
}
enum ThunderboltBridgeError: LocalizedError {
case authorizationFailed
case authorizationCanceled
case authorizationDenied
case preferencesCreationFailed
case lockFailed
case servicesNotFound
case serviceNotFound
case disableFailed
case commitFailed
case applyFailed
var errorDescription: String? {
switch self {
case .authorizationFailed:
return "Failed to create authorization"
case .authorizationCanceled:
return "Authorization was canceled by user"
case .authorizationDenied:
return "Authorization was denied"
case .preferencesCreationFailed:
return "Failed to access network preferences"
case .lockFailed:
return "Failed to lock network preferences for modification"
case .servicesNotFound:
return "Could not retrieve network services"
case .serviceNotFound:
return "Thunderbolt Bridge service not found"
case .disableFailed:
return "Failed to disable Thunderbolt Bridge service"
case .commitFailed:
return "Failed to save network configuration changes"
case .applyFailed:
return "Failed to apply network configuration changes"
}
}
}

View File

@@ -86,7 +86,7 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let topologyNodeIds = Set(topology?.nodes ?? [])
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
@@ -95,8 +95,8 @@ extension ClusterState {
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
var orderedNodes: [NodeViewModel] = []
if let topologyNodes = topology?.nodes {
for topoNode in topologyNodes {
if let viewModel = nodesById[topoNode.nodeId] {
for nodeId in topologyNodes {
if let viewModel = nodesById[nodeId] {
orderedNodes.append(viewModel)
}
}
@@ -116,7 +116,7 @@ extension ClusterState {
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
topology?.connections.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -5,22 +5,32 @@
topologyData,
isTopologyMinimized,
debugMode,
nodeThunderboltBridge,
type NodeInfo,
} from "$lib/stores/app.svelte";
interface Props {
class?: string;
highlightedNodes?: Set<string>;
filteredNodes?: Set<string>;
onNodeClick?: (nodeId: string) => void;
}
let { class: className = "", highlightedNodes = new Set() }: Props = $props();
let {
class: className = "",
highlightedNodes = new Set(),
filteredNodes = new Set(),
onNodeClick,
}: Props = $props();
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
let hoveredNodeId = $state<string | null>(null);
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -522,10 +532,72 @@
}
}
let iconBaseWidth = nodeRadius * 1.2;
let iconBaseHeight = nodeRadius * 1.0;
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
const modelLower = modelId.toLowerCase();
// Check node states for styling
const isHighlighted = highlightedNodes.has(nodeInfo.id);
const isInFilter =
filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);
const isFilteredOut =
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
const wireColor = isInFilter
? "rgba(255,215,0,1)" // Bright yellow for filter selection
: isHovered
? "rgba(255,215,0,0.7)" // Subtle yellow for hover
: isHighlighted
? "rgba(255,215,0,0.9)" // Yellow for instance highlight
: isFilteredOut
? "rgba(140,140,140,0.6)" // Grey for filtered out
: "rgba(179,179,179,0.8)"; // Default
const wireColorBright = "rgba(255,255,255,0.9)";
const fillColor = isInFilter
? "rgba(255,215,0,0.25)"
: isHovered
? "rgba(255,215,0,0.12)"
: isHighlighted
? "rgba(255,215,0,0.15)"
: "rgba(255,215,0,0.08)";
const strokeWidth = isInFilter
? 3
: isHovered
? 2
: isHighlighted
? 2.5
: 1.5;
const screenFill = "rgba(0,20,40,0.9)";
const glowColor = "rgba(255,215,0,0.3)";
const nodeG = nodesGroup
.append("g")
.attr("class", "graph-node")
.style("cursor", "pointer");
.style("cursor", onNodeClick ? "pointer" : "default")
.style("opacity", isFilteredOut ? 0.5 : 1);
// Add click and hover handlers - hover just updates state, styling is applied during render
nodeG
.on("click", (event: MouseEvent) => {
if (onNodeClick) {
event.stopPropagation();
onNodeClick(nodeInfo.id);
}
})
.on("mouseenter", () => {
if (onNodeClick) {
hoveredNodeId = nodeInfo.id;
}
})
.on("mouseleave", () => {
if (hoveredNodeId === nodeInfo.id) {
hoveredNodeId = null;
}
});
// Add tooltip
nodeG
@@ -534,27 +606,6 @@
`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,
);
let iconBaseWidth = nodeRadius * 1.2;
let iconBaseHeight = nodeRadius * 1.0;
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
const modelLower = modelId.toLowerCase();
// Check if this node should be highlighted (from hovered instance)
const isHighlighted = highlightedNodes.has(nodeInfo.id);
// Holographic wireframe colors - yellow border when highlighted
const wireColor = isHighlighted
? "rgba(255,215,0,0.9)"
: "rgba(179,179,179,0.8)";
const wireColorBright = "rgba(255,255,255,0.9)";
const fillColor = isHighlighted
? "rgba(255,215,0,0.15)"
: "rgba(255,215,0,0.08)";
const strokeWidth = isHighlighted ? 2.5 : 1.5;
const screenFill = "rgba(0,20,40,0.9)";
const glowColor = "rgba(255,215,0,0.3)";
if (modelLower === "mac studio") {
// Mac Studio - classic cube with memory fill
iconBaseWidth = nodeRadius * 1.25;
@@ -579,6 +630,7 @@
// Main body (uniform color)
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", x)
.attr("y", y)
.attr("width", iconBaseWidth)
@@ -661,6 +713,7 @@
// Main body (uniform color)
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", x)
.attr("y", y)
.attr("width", iconBaseWidth)
@@ -738,6 +791,7 @@
// Screen outer frame
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", screenX)
.attr("y", y)
.attr("width", screenWidth)
@@ -846,6 +900,7 @@
// Main shape
nodeG
.append("polygon")
.attr("class", "node-outline")
.attr("points", hexPoints)
.attr("fill", fillColor)
.attr("stroke", wireColor)
@@ -1064,11 +1119,41 @@
.attr("fill", "rgba(179,179,179,0.7)")
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}
// Debug mode: Show TB bridge status
if (debugEnabled) {
const tbStatus = tbBridgeData[nodeInfo.id];
if (tbStatus) {
const tbY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const tbFontSize = showFullLabels ? 9 : 7;
const tbColor = tbStatus.enabled
? "rgba(234,179,8,0.9)"
: "rgba(100,100,100,0.7)";
const tbText = tbStatus.enabled ? "TB:ON" : "TB:OFF";
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", tbY)
.attr("text-anchor", "middle")
.attr("fill", tbColor)
.attr("font-size", tbFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(tbText);
}
}
});
}
$effect(() => {
if (data) {
// Track all reactive dependencies that affect rendering
const _data = data;
const _hoveredNodeId = hoveredNodeId;
const _filteredNodes = filteredNodes;
const _highlightedNodes = highlightedNodes;
if (_data) {
renderGraph();
}
});
@@ -1091,12 +1176,8 @@
<style>
:global(.graph-node) {
transition:
transform 0.2s ease,
opacity 0.2s ease;
}
:global(.graph-node:hover) {
filter: brightness(1.1);
/* Only transition opacity for filtered-out nodes, no transition on hover stroke changes */
transition: opacity 0.2s ease;
}
:global(.graph-link) {
stroke: var(--exo-light-gray, #b3b3b3);

View File

@@ -190,6 +190,13 @@ interface RawStateResponse {
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
// Thunderbolt bridge status per node
nodeThunderboltBridge?: Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
}
export interface MessageAttachment {
@@ -419,7 +426,15 @@ class AppStore {
placementPreviews = $state<PlacementPreview[]>([]);
selectedPreviewModelId = $state<string | null>(null);
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderboltBridge = $state<
Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>
>({});
// UI state
isTopologyMinimized = $state(false);
@@ -439,6 +454,7 @@ class AppStore {
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
constructor() {
if (browser) {
@@ -997,6 +1013,8 @@ class AppStore {
nodeSystem: data.nodeSystem,
nodeNetwork: data.nodeNetwork,
});
// Handle topology changes for preview filter
this.handleTopologyChange();
}
if (data.instances) {
this.instances = data.instances;
@@ -1008,6 +1026,10 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};
this.lastUpdate = Date.now();
} catch (error) {
console.error("Error fetching state:", error);
@@ -1023,9 +1045,14 @@ class AppStore {
this.selectedPreviewModelId = modelId;
try {
const response = await fetch(
`/instance/previews?model_id=${encodeURIComponent(modelId)}`,
);
let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;
// Add node filter if active
if (this.previewNodeFilter.size > 0) {
for (const nodeId of this.previewNodeFilter) {
url += `&node_ids=${encodeURIComponent(nodeId)}`;
}
}
const response = await fetch(url);
if (!response.ok) {
throw new Error(
`Failed to fetch placement previews: ${response.status}`,
@@ -1075,6 +1102,71 @@ class AppStore {
}
}
/**
* Toggle a node in the preview filter and re-fetch placements
*/
togglePreviewNodeFilter(nodeId: string) {
const next = new Set(this.previewNodeFilter);
if (next.has(nodeId)) {
next.delete(nodeId);
} else {
next.add(nodeId);
}
this.previewNodeFilter = next;
// Re-fetch with new filter if we have a selected model
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
/**
* Clear the preview node filter and re-fetch placements
*/
clearPreviewNodeFilter() {
this.previewNodeFilter = new Set();
// Re-fetch with no filter if we have a selected model
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
/**
* Handle topology changes - clean up filter and re-fetch if needed
*/
private handleTopologyChange() {
if (!this.topologyData) return;
const currentNodeIds = new Set(Object.keys(this.topologyData.nodes));
// Check if nodes have changed
const nodesAdded = [...currentNodeIds].some(
(id) => !this.previousNodeIds.has(id),
);
const nodesRemoved = [...this.previousNodeIds].some(
(id) => !currentNodeIds.has(id),
);
if (nodesAdded || nodesRemoved) {
// Clean up filter - remove any nodes that no longer exist
if (this.previewNodeFilter.size > 0) {
const validFilterNodes = new Set(
[...this.previewNodeFilter].filter((id) => currentNodeIds.has(id)),
);
if (validFilterNodes.size !== this.previewNodeFilter.size) {
this.previewNodeFilter = validFilterNodes;
}
}
// Re-fetch previews if we have a selected model (topology changed)
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
// Update tracked node IDs for next comparison
this.previousNodeIds = currentNodeIds;
}
/**
* Starts a chat conversation - triggers the topology minimization animation
* Creates a new conversation if none is active
@@ -2098,6 +2190,10 @@ export const setSelectedChatModel = (modelId: string) =>
appStore.setSelectedModel(modelId);
export const selectPreviewModel = (modelId: string | null) =>
appStore.selectPreviewModel(modelId);
export const togglePreviewNodeFilter = (nodeId: string) =>
appStore.togglePreviewNodeFilter(nodeId);
export const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();
export const previewNodeFilter = () => appStore.previewNodeFilter;
export const deleteMessage = (messageId: string) =>
appStore.deleteMessage(messageId);
export const editMessage = (messageId: string, newContent: string) =>
@@ -2134,6 +2230,10 @@ export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Thunderbolt bridge status
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
// Image generation params
export const imageGenerationParams = () => appStore.getImageGenerationParams();
export const setImageGenerationParams = (

View File

@@ -19,6 +19,9 @@
selectedPreviewModelId,
isLoadingPreviews,
selectPreviewModel,
togglePreviewNodeFilter,
clearPreviewNodeFilter,
previewNodeFilter,
createConversation,
setSelectedChatModel,
selectedChatModel,
@@ -28,6 +31,8 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
thunderboltBridgeCycles,
nodeThunderboltBridge,
type DownloadProgress,
type PlacementPreview,
} from "$lib/stores/app.svelte";
@@ -49,6 +54,41 @@
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
const nodeFilter = $derived(previewNodeFilter());
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8) + "...";
}
// Helper to get the thunderbolt bridge service name from a cycle
function getTbBridgeServiceName(cycle: string[]): string {
// Try to find service name from any node in the cycle
for (const nodeId of cycle) {
const nodeData = tbBridgeData?.[nodeId];
if (nodeData?.serviceName) {
return nodeData.serviceName;
}
}
return "Thunderbolt Bridge"; // Fallback if no service name found
}
// Copy to clipboard state and function
let copiedCommand = $state(false);
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text);
copiedCommand = true;
setTimeout(() => {
copiedCommand = false;
}, 2000);
} catch (err) {
console.error("Failed to copy:", err);
}
}
let mounted = $state(false);
@@ -90,6 +130,15 @@
model.tasks.includes("ImageToImage")
);
}
// Helper to check if a model supports image editing
function modelSupportsImageEditing(modelId: string): boolean {
const model = models.find(
(m) => m.id === modelId || m.hugging_face_id === modelId,
);
if (!model?.tasks) return false;
return model.tasks.includes("ImageToImage");
}
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
@@ -181,6 +230,9 @@
// Preview card hover state for highlighting nodes in topology
let hoveredPreviewNodes = $state<Set<string>>(new Set());
// Computed: Check if filter is active (from store)
const isFilterActive = $derived(() => nodeFilter.size > 0);
// Helper to unwrap tagged instance for hover highlighting
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
if (!instanceWrapped || typeof instanceWrapped !== "object")
@@ -732,6 +784,8 @@
instanceWrapped: unknown,
): {
isDownloading: boolean;
isFailed: boolean;
errorMessage: string | null;
progress: DownloadProgress | null;
statusText: string;
perNode: Array<{
@@ -743,6 +797,8 @@
if (!downloadsData || Object.keys(downloadsData).length === 0) {
return {
isDownloading: false,
isFailed: false,
errorMessage: null,
progress: null,
statusText: "RUNNING",
perNode: [],
@@ -754,6 +810,8 @@
if (!instance || typeof instance !== "object") {
return {
isDownloading: false,
isFailed: false,
errorMessage: null,
progress: null,
statusText: "PREPARING",
perNode: [],
@@ -809,6 +867,26 @@
downloadKind
] as Record<string, unknown>;
// Handle DownloadFailed - return immediately with error info
if (downloadKind === "DownloadFailed") {
const downloadModelId = extractModelIdFromDownload(downloadPayload);
if (
instanceModelId &&
downloadModelId &&
downloadModelId === instanceModelId
) {
return {
isDownloading: false,
isFailed: true,
errorMessage:
(downloadPayload.errorMessage as string) || "Download failed",
progress: null,
statusText: "FAILED",
perNode: [],
};
}
}
if (downloadKind !== "DownloadOngoing") continue;
if (!downloadPayload) continue;
@@ -844,6 +922,8 @@
const statusInfo = deriveInstanceStatus(instanceWrapped);
return {
isDownloading: false,
isFailed: statusInfo.statusText === "FAILED",
errorMessage: null,
progress: null,
statusText: statusInfo.statusText,
perNode: [],
@@ -856,6 +936,8 @@
return {
isDownloading: true,
isFailed: false,
errorMessage: null,
progress: {
totalBytes,
downloadedBytes,
@@ -1451,7 +1533,7 @@
// Get ALL filtered previews based on current settings (matching minimum nodes)
// Note: previewsData already contains previews for the selected model (fetched via API)
// We filter by sharding/instance type and min nodes, returning ALL eligible previews
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
const filteredPreviews = $derived(() => {
if (!selectedModelId || previewsData.length === 0) return [];
@@ -1584,7 +1666,86 @@
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div class="absolute top-4 left-4 group" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-yellow-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-sm font-mono text-yellow-200">
THUNDERBOLT BRIDGE CYCLE DETECTED
</span>
</div>
<!-- Tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-2">
A network routing cycle was detected between nodes connected
via Thunderbolt Bridge. This can cause connectivity issues.
</p>
<p class="text-xs text-white/60 mb-2">
<span class="text-yellow-300">Affected nodes:</span>
{cycle.map(getNodeName).join(" → ")}
</p>
<p class="text-xs text-white/60 mb-1">
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
Bridge on one of the affected nodes:
</p>
<button
type="button"
onclick={() => copyToClipboard(disableCmd)}
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
title="Click to copy"
>
<span class="flex-1">{disableCmd}</span>
<svg
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
{#if copiedCommand}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M5 13l4 4L19 7"
/>
{:else}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
/>
{/if}
</svg>
</button>
</div>
</div>
{/if}
<!-- Exit topology-only mode button -->
<button
type="button"
@@ -1624,7 +1785,111 @@
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div class="absolute top-4 left-4 group" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-yellow-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-sm font-mono text-yellow-200">
THUNDERBOLT BRIDGE CYCLE DETECTED
</span>
</div>
<!-- Tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-2">
A network routing cycle was detected between nodes connected
via Thunderbolt Bridge. This can cause connectivity issues.
</p>
<p class="text-xs text-white/60 mb-2">
<span class="text-yellow-300">Affected nodes:</span>
{cycle.map(getNodeName).join(" → ")}
</p>
<p class="text-xs text-white/60 mb-1">
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
Bridge on one of the affected nodes:
</p>
<button
type="button"
onclick={() => copyToClipboard(disableCmd)}
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
title="Click to copy"
>
<span class="flex-1">{disableCmd}</span>
<svg
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
{#if copiedCommand}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M5 13l4 4L19 7"
/>
{:else}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
/>
{/if}
</svg>
</button>
</div>
</div>
{/if}
<!-- Node Filter Indicator (top-right corner) -->
{#if isFilterActive()}
<button
onclick={clearPreviewNodeFilter}
class="absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm"
title="Clear filter"
>
<span class="text-[10px] font-mono tracking-wider">
FILTER: {nodeFilter.size}
</span>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
<!-- Chat Input - Below topology -->
@@ -2061,6 +2326,13 @@
>
{downloadInfo.statusText}
</div>
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
<div
class="text-xs text-red-400/80 font-mono mt-1 break-words"
>
{downloadInfo.errorMessage}
</div>
{/if}
{/if}
</div>
</div>
@@ -2106,6 +2378,9 @@
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
@@ -2132,6 +2407,22 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2204,6 +2495,9 @@
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
@@ -2244,6 +2538,23 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span
@@ -2564,7 +2875,36 @@
<div
class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden"
>
<TopologyGraph highlightedNodes={highlightedNodes()} />
<TopologyGraph
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
{#if tbBridgeCycles.length > 0}
<div
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
title="Thunderbolt Bridge cycle detected - click to view details"
>
<svg
class="w-3.5 h-3.5 text-yellow-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-[10px] font-mono text-yellow-200"
>TB CYCLE</span
>
</div>
{/if}
</div>
</button>
@@ -2993,6 +3333,13 @@
>
{downloadInfo.statusText}
</div>
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
<div
class="text-xs text-red-400/80 font-mono mt-1 break-words"
>
{downloadInfo.errorMessage}
</div>
{/if}
{/if}
</div>
</div>

View File

@@ -172,33 +172,6 @@
}
let downloadOverview = $state<NodeEntry[]>([]);
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
[],
);
async function fetchModels() {
try {
const response = await fetch("/models");
if (response.ok) {
const data = await response.json();
models = data.data || [];
}
} catch (error) {
console.error("Failed to fetch models:", error);
}
}
function getModelTotalBytes(
modelId: string,
downloadTotalBytes: number,
): number {
if (downloadTotalBytes > 0) return downloadTotalBytes;
const model = models.find((m) => m.id === modelId);
if (model?.storage_size_megabytes) {
return model.storage_size_megabytes * 1024 * 1024;
}
return 0;
}
$effect(() => {
try {
@@ -373,7 +346,6 @@
onMount(() => {
// Ensure we fetch at least once when visiting downloads directly
refreshState();
fetchModels();
});
</script>
@@ -482,7 +454,7 @@
{#if model.status !== "completed"}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(
getModelTotalBytes(model.modelId, model.totalBytes),
model.totalBytes,
)}
</div>
{/if}

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer

View File

@@ -3,12 +3,13 @@ import json
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Literal, cast
from typing import Annotated, Literal, cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -20,7 +21,10 @@ from loguru import logger
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_IMAGE_CACHE_DIR, EXO_MAX_CHUNK_SIZE
from exo.shared.constants import (
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
)
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
@@ -56,8 +60,15 @@ from exo.shared.types.api import (
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -93,7 +104,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -102,7 +113,19 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
if isinstance(chunk, TokenChunk)
else ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
finish_reason=chunk.finish_reason,
)
],
@@ -162,8 +185,12 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
@@ -310,11 +337,20 @@ class API:
return placements[new_ids[0]]
async def get_placement_previews(
self, model_id: ModelId
self,
model_id: ModelId,
node_ids: Annotated[list[NodeId] | None, Query()] = None,
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
if len(list(self.state.topology.list_nodes())) == 0:
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -327,9 +363,7 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
for i in range(1, len(list(topology.list_nodes())) + 1)
]
)
# TODO: PDD
@@ -347,7 +381,7 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
topology=topology,
current_instances=self.state.instances,
)
except ValueError as exc:
@@ -439,11 +473,13 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
@@ -462,7 +498,8 @@ class API:
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
del self._chat_completion_queues[command_id]
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
@@ -470,6 +507,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
@@ -498,11 +536,12 @@ class API:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -511,7 +550,18 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -529,6 +579,7 @@ class API:
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
@@ -539,6 +590,7 @@ class API:
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
@@ -554,7 +606,19 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
@@ -571,7 +635,7 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
role="assistant", content=combined_text, tool_calls=tool_calls
),
finish_reason=finish_reason,
)
@@ -729,7 +793,9 @@ class API:
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
with recv as chunks:
async for chunk in chunks:
@@ -838,7 +904,9 @@ class API:
stats: ImageGenerationStats | None = None
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
while images_complete < num_images:
with recv as chunks:
@@ -994,7 +1062,6 @@ class API:
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
@@ -1148,27 +1215,26 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(
event.command_id, None
)
elif event.command_id in self._image_generation_queues:
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
queue = self._image_generation_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._chat_completion_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
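A minimal client-side sketch of the new `node_ids` filtering, for illustration only: FastAPI collects every repeated `node_ids` query parameter into the `list[NodeId] | None` argument of `get_placement_previews` above. The route path is taken from the dashboard fetch earlier in this diff (`/instance/previews`); the base URL/port, model id and node ids below are assumptions, not values from the repository.

```python
# Sketch: repeated node_ids query parameters arrive as a list on the server side.
import json
import urllib.parse
import urllib.request

params = [
    ("model_id", "mlx-community/GLM-4.7-Flash-4bit"),  # hypothetical model id
    ("node_ids", "node-aaaa"),                          # hypothetical node ids
    ("node_ids", "node-bbbb"),
]
# Base URL is an assumption; point this at your exo API instance.
url = "http://localhost:52415/instance/previews?" + urllib.parse.urlencode(params)
with urllib.request.urlopen(url) as resp:
    previews = json.load(resp)["previews"]
print(f"{len(previews)} placement previews restricted to the selected nodes")
```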

View File

@@ -87,12 +87,12 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle

View File

@@ -197,49 +197,6 @@ def get_shard_assignments(
)
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
cycles = cycle_digraph.get_cycles()
expected_length = len(list(cycle_digraph.list_nodes()))
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
if not cycles:
if expected_length > 1:
logger.warning(
f"No cycles of length {expected_length} found even though chosen subgraph contained {expected_length} nodes"
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
get_thunderbolt = True
logger.debug(f"Using thunderbolt cycle: {get_thunderbolt}")
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
cycle_digraph: Topology,
@@ -265,9 +222,6 @@ def get_mlx_jaccl_devices_matrix(
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
)
@@ -279,22 +233,11 @@ def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
) -> Generator[str, None, None]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
def _find_interface_name_for_ip(
ip_address: str, node_network: NodeNetworkInfo
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_network.interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
yield connection.sink_multiaddr.ip_address
def _find_ip_prioritised(
@@ -303,43 +246,19 @@ def _find_ip_prioritised(
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
Priority: ethernet > wifi > unknown > thunderbolt
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
for ip, _ in ips
if not ips:
return None
other_network = node_network.get(other_node_id, NodeNetworkInfo())
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
en0_ip = iface_map.get("en0")
if en0_ip:
return en0_ip
en1_ip = iface_map.get("en1")
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
def get_mlx_ring_hosts_by_node(
@@ -381,9 +300,6 @@ def get_mlx_ring_hosts_by_node(
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
@@ -416,9 +332,6 @@ def get_mlx_jaccl_coordinators(
if ip is not None:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)
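To make the new prioritisation in `_find_ip_prioritised` concrete, here is a standalone sketch of the same min-over-a-priority-map pattern; the addresses and interface types are invented for illustration.

```python
# Candidate IPs are ranked by the interface type of the peer that owns them;
# min() returns the best-ranked address (ethernet > wifi > unknown > thunderbolt).
ip_to_type = {
    "192.168.1.10": "ethernet",
    "192.168.2.10": "wifi",
    "169.254.77.5": "thunderbolt",
}
ips = ["169.254.77.5", "192.168.2.10", "192.168.1.10"]

priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
best = min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
assert best == "192.168.1.10"  # the ethernet address wins over wifi and thunderbolt
```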

View File

@@ -1,13 +1,9 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason
def test_image_chunk_with_error_fields() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message="Image generation failed",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Image generation failed"
assert chunk.data == ""
assert chunk.chunk_index == 0
assert chunk.total_chunks == 1
assert chunk.image_index == 0
def test_image_chunk_without_error() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="base64encodeddata",
chunk_index=0,
total_chunks=1,
image_index=0,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
assert chunk.data == "base64encodeddata"

View File

@@ -3,7 +3,6 @@ import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
@@ -14,7 +13,7 @@ from exo.master.tests.conftest import (
)
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
@@ -273,45 +272,6 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
# act
hosts = get_hosts_from_subgraph(topology)
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()

View File

@@ -30,6 +30,7 @@ from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -46,6 +47,7 @@ from exo.utils.info_gatherer.info_gatherer import (
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
ThunderboltBridgeInfo,
)
@@ -225,6 +227,21 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
node_thunderbolt_bridge = {
key: value
for key, value in state.node_thunderbolt_bridge.items()
if key != event.node_id
}
# Only recompute cycles if the leaving node had TB bridge enabled
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
leaving_node_had_tb_enabled = (
leaving_node_status is not None and leaving_node_status.enabled
)
thunderbolt_bridge_cycles = (
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
if leaving_node_had_tb_enabled
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
)
return state.model_copy(
update={
"downloads": downloads,
@@ -235,6 +252,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
"node_thunderbolt_bridge": node_thunderbolt_bridge,
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
}
)
@@ -312,6 +331,22 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
case ThunderboltBridgeInfo():
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
**state.node_thunderbolt_bridge,
event.node_id: info.status,
}
update["node_thunderbolt_bridge"] = new_tb_bridge
# Only recompute cycles if the enabled status changed
old_status = state.node_thunderbolt_bridge.get(event.node_id)
old_enabled = old_status.enabled if old_status else False
new_enabled = info.status.enabled
if old_enabled != new_enabled:
update["thunderbolt_bridge_cycles"] = (
topology.get_thunderbolt_bridge_cycles(
new_tb_bridge, state.node_network
)
)
return state.model_copy(update=update)

View File

@@ -49,3 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)

View File

@@ -9,6 +9,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
@@ -410,161 +411,166 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
# "flux1-schnell": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23782357120),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "flux1-dev": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23802816640),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image-edit-2509": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
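The image model cards above are only merged into `MODEL_CARDS` when the `EXO_ENABLE_IMAGE_MODELS` flag from `exo.shared.constants` is truthy. A rough usage sketch, assuming the environment variable is set before any exo module is imported (the flag is read once at import time):

```python
import os

# Must happen before exo.shared.constants is first imported anywhere in the process.
os.environ["EXO_ENABLE_IMAGE_MODELS"] = "true"

from exo.shared.models.model_cards import MODEL_CARDS

# With the flag enabled, the image cards are registered alongside the text models.
assert "flux1-schnell" in MODEL_CARDS
```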
@@ -627,7 +633,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
@@ -650,7 +656,7 @@ async def get_safetensors_size(model_id: ModelId) -> Memory:
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)

View File

@@ -7,6 +7,11 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
InterfaceType,
NodeNetworkInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.topology import (
Connection,
Cycle,
@@ -188,24 +193,25 @@ class Topology:
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
tb_edges = [
def get_rdma_cycles(self) -> list[Cycle]:
rdma_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
if isinstance(conn, RDMAConnection)
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
rx.PyDiGraph()
)
rdma_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
for u, v, conn in rdma_edges:
rdma_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycle_idxs = rx.simple_cycles(rdma_graph)
cycles: list[Cycle] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])
cycles.append(cycle)
return cycles
@@ -219,18 +225,83 @@ class Topology:
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
def is_rdma_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:
continue
has_tb = False
has_rdma = False
for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
if edge.is_thunderbolt():
has_tb = True
if isinstance(edge, RDMAConnection):
has_rdma = True
break
if not has_tb:
if not has_rdma:
return False
return True
def get_thunderbolt_bridge_cycles(
self,
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
for node_id, status in node_tb_bridge_status.items()
if status.enabled
}
if len(enabled_nodes) < 3:
return []
thunderbolt_ips = _get_ips_with_interface_type(
enabled_nodes, node_network, "thunderbolt"
)
# Build subgraph with only TB bridge enabled nodes and thunderbolt connections
graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()
node_to_idx: dict[NodeId, int] = {}
for node_id in enabled_nodes:
if node_id in self._vertex_indices:
node_to_idx[node_id] = graph.add_node(node_id)
for u, v, conn in self._graph.weighted_edge_list():
source_id, sink_id = self._graph[u], self._graph[v]
if source_id not in node_to_idx or sink_id not in node_to_idx:
continue
# Include connection if it's over a thunderbolt interface
if (
isinstance(conn, SocketConnection)
and conn.sink_multiaddr.ip_address in thunderbolt_ips
):
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
if isinstance(conn, RDMAConnection):
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) > 2
]
def _get_ips_with_interface_type(
node_ids: set[NodeId],
node_network: Mapping[NodeId, NodeNetworkInfo],
interface_type: InterfaceType,
) -> set[str]:
"""Get all IP addresses on interfaces of the specified type for the given nodes."""
ips: set[str] = set()
for node_id in node_ids:
network_info = node_network.get(node_id, NodeNetworkInfo())
for iface in network_info.interfaces:
if iface.interface_type == interface_type:
ips.add(iface.ip_address)
return ips

View File

@@ -54,6 +54,18 @@ class ChatCompletionMessageText(BaseModel):
text: str
class ToolCallItem(BaseModel):
name: str
arguments: str
class ToolCall(BaseModel):
id: str
index: int | None = None
type: Literal["function"] = "function"
function: ToolCallItem
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: (
@@ -61,7 +73,7 @@ class ChatCompletionMessage(BaseModel):
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None
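To make the new message shape concrete, here is a standalone re-declaration of the two models added above, validating an OpenAI-style payload. The payload itself is illustrative; note that `arguments` stays a JSON string at this layer:

```python
from typing import Literal
from pydantic import BaseModel

class ToolCallItem(BaseModel):
    name: str
    arguments: str  # JSON-encoded string, as in the OpenAI chat format

class ToolCall(BaseModel):
    id: str
    index: int | None = None
    type: Literal["function"] = "function"
    function: ToolCallItem

payload = {
    "id": "call_0",
    "type": "function",
    "function": {"name": "multiply", "arguments": '{"a": 2, "b": 3}'},
}
print(ToolCall.model_validate(payload))
```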

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from enum import Enum
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
class ChunkType(str, Enum):
Token = "Token"
Image = "Image"
from .worker.runner_response import ToolCallItem
class BaseChunk(TaggedModel):
idx: int
model: ModelId
class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
class ErrorChunk(BaseChunk):
error_message: str
finish_reason: Literal["error"] = "error"
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
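A consumer of the widened `GenerationChunk` union now has four variants to branch on. A hedged sketch of what that dispatch might look like (the import path is taken from this diff; the handler itself is illustrative, not exo code):

```python
from exo.shared.types.chunks import (
    ErrorChunk,
    GenerationChunk,
    ImageChunk,
    TokenChunk,
    ToolCallChunk,
)

def handle(chunk: GenerationChunk) -> None:
    match chunk:
        case TokenChunk():
            print(chunk.text, end="")
        case ToolCallChunk():
            print([tc.name for tc in chunk.tool_calls])
        case ErrorChunk():
            print(f"error: {chunk.error_message}")
        case ImageChunk():
            pass  # image payloads are reassembled elsewhere
```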

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Self
from typing import Literal, Self
import psutil
@@ -48,9 +48,13 @@ class SystemPerformanceProfile(CamelCaseModel):
ecpu_usage: float = 0.0
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
class NetworkInterfaceInfo(CamelCaseModel):
name: str
ip_address: str
interface_type: InterfaceType = "unknown"
class NodeIdentity(CamelCaseModel):
@@ -71,3 +75,11 @@ class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
class ThunderboltBridgeStatus(CamelCaseModel):
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
enabled: bool
exists: bool
service_name: str | None = None

View File

@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
ThunderboltBridgeStatus,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:

View File

@@ -21,9 +21,6 @@ class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
@@ -31,9 +28,6 @@ class SocketConnection(FrozenModel):
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId

View File

@@ -1,7 +1,12 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageGenerationStats,
ToolCallItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -48,5 +53,9 @@ class PartialImageResponse(BaseRunnerResponse):
yield name, value
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
IS_DARWIN = sys.platform == "darwin"
async def _get_thunderbolt_devices() -> set[str] | None:
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listallhardwareports"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listallhardwareports failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
output = result.stdout.decode()
thunderbolt_devices: set[str] = set()
current_port: str | None = None
for line in output.splitlines():
line = line.strip()
if line.startswith("Hardware Port:"):
current_port = line.split(":", 1)[1].strip()
elif line.startswith("Device:") and current_port:
device = line.split(":", 1)[1].strip()
if "thunderbolt" in current_port.lower():
thunderbolt_devices.add(device)
current_port = None
return thunderbolt_devices
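The parse above is easiest to follow against a sample transcript. A standalone illustration of the same logic run over a made-up `networksetup -listallhardwareports` output:

```python
sample = """\
Hardware Port: Wi-Fi
Device: en0
Ethernet Address: aa:bb:cc:dd:ee:ff

Hardware Port: Thunderbolt 1
Device: en2
Ethernet Address: (null)
"""

thunderbolt_devices: set[str] = set()
current_port: str | None = None
for line in sample.splitlines():
    line = line.strip()
    if line.startswith("Hardware Port:"):
        current_port = line.split(":", 1)[1].strip()
    elif line.startswith("Device:") and current_port:
        device = line.split(":", 1)[1].strip()
        if "thunderbolt" in current_port.lower():
            thunderbolt_devices.add(device)
        current_port = None

print(thunderbolt_devices)  # {'en2'}
```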
async def _get_bridge_services() -> dict[str, str] | None:
"""Get mapping of bridge device -> service name from network service order.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listnetworkserviceorder"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listnetworkserviceorder failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
# Parse service order to find bridge devices and their service names
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
service_order_output = result.stdout.decode()
bridge_services: dict[str, str] = {} # device -> service name
current_service: str | None = None
for line in service_order_output.splitlines():
line = line.strip()
# Match "(N) Service Name" or "(*) Service Name" (disabled)
# but NOT "(Hardware Port: ...)" lines
if (
line
and line.startswith("(")
and ")" in line
and not line.startswith("(Hardware Port:")
):
paren_end = line.index(")")
if paren_end + 2 <= len(line):
current_service = line[paren_end + 2 :]
# Match "(Hardware Port: ..., Device: bridgeX)"
elif current_service and "Device: bridge" in line:
# Extract device name from "..., Device: bridge0)"
device_start = line.find("Device: ") + len("Device: ")
device_end = line.find(")", device_start)
if device_end > device_start:
device = line[device_start:device_end]
bridge_services[device] = current_service
return bridge_services
async def _get_bridge_members(bridge_device: str) -> set[str]:
"""Get member interfaces of a bridge device via ifconfig."""
result = await anyio.run_process(
["ifconfig", bridge_device],
check=False,
)
if result.returncode != 0:
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
return set()
members: set[str] = set()
ifconfig_output = result.stdout.decode()
for line in ifconfig_output.splitlines():
line = line.strip()
if line.startswith("member:"):
parts = line.split()
if len(parts) > 1:
members.add(parts[1])
return members
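The `member:` parse in `_get_bridge_members` can be illustrated the same way, against a made-up `ifconfig bridge0` transcript:

```python
sample = """\
bridge0: flags=8863<UP,BROADCAST,SMART,RUNNING,SIMPLEX,MULTICAST> mtu 1500
    member: en2 flags=3<LEARNING,DISCOVER>
    member: en3 flags=3<LEARNING,DISCOVER>
"""

# Each bridge member shows up as "member: <iface> flags=..."
members = {
    line.strip().split()[1]
    for line in sample.splitlines()
    if line.strip().startswith("member:") and len(line.split()) > 1
}
print(members)  # {'en2', 'en3'}
```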
async def _find_thunderbolt_bridge(
bridge_services: dict[str, str], thunderbolt_devices: set[str]
) -> str | None:
"""Find the service name of a bridge containing Thunderbolt interfaces.
Returns the service name if found, None otherwise.
"""
for bridge_device, service_name in bridge_services.items():
members = await _get_bridge_members(bridge_device)
if members & thunderbolt_devices: # intersection is non-empty
return service_name
return None
async def _is_service_enabled(service_name: str) -> bool | None:
"""Check if a network service is enabled.
Returns True if enabled, False if disabled, None on error.
"""
result = await anyio.run_process(
["networksetup", "-getnetworkserviceenabled", service_name],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -getnetworkserviceenabled '{service_name}' "
f"failed with code {result.returncode}: {result.stderr.decode()}"
)
return None
stdout = result.stdout.decode().strip().lower()
return stdout == "enabled"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
@@ -58,6 +195,66 @@ class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class ThunderboltBridgeInfo(TaggedModel):
status: ThunderboltBridgeStatus
@classmethod
async def gather(cls) -> Self | None:
"""Check if a Thunderbolt Bridge network service is enabled on this node.
Detection approach:
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
2. Find bridge devices from network service order (not hardware ports, as
bridges may not appear there)
3. Check each bridge's members via ifconfig
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
5. Check if that network service is enabled
"""
if not IS_DARWIN:
return None
def _no_bridge_status() -> Self:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=False, service_name=None
)
)
try:
tb_devices = await _get_thunderbolt_devices()
if tb_devices is None:
return _no_bridge_status()
bridge_services = await _get_bridge_services()
if not bridge_services:
return _no_bridge_status()
tb_service_name = await _find_thunderbolt_bridge(
bridge_services, tb_devices
)
if not tb_service_name:
return _no_bridge_status()
enabled = await _is_service_enabled(tb_service_name)
if enabled is None:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=True, service_name=tb_service_name
)
)
return cls(
status=ThunderboltBridgeStatus(
enabled=enabled,
exists=True,
service_name=tb_service_name,
)
)
except Exception as e:
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
return None
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@@ -111,6 +308,7 @@ GatheredInfo = (
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| ThunderboltBridgeInfo
| NodeConfig
| MiscData
| StaticNodeInformation
@@ -125,6 +323,7 @@ class InfoGatherer:
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
@@ -133,6 +332,7 @@ class InfoGatherer:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
@@ -206,6 +406,17 @@ class InfoGatherer:
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return

View File

@@ -1,11 +1,11 @@
import socket
import sys
from subprocess import CalledProcessError
from subprocess import CalledProcessError, run
import psutil
from anyio import run_process
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
async def get_friendly_name() -> str:
@@ -28,6 +28,38 @@ async def get_friendly_name() -> str:
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
"""Parse networksetup -listallhardwareports to get interface types."""
if sys.platform != "darwin":
return {}
try:
result = run(
["networksetup", "-listallhardwareports"], capture_output=True, text=True
)
except Exception:
return {}
types: dict[str, InterfaceType] = {}
current_type: InterfaceType = "unknown"
for line in result.stdout.splitlines():
if line.startswith("Hardware Port:"):
port_name = line.split(":", 1)[1].strip()
if "Wi-Fi" in port_name:
current_type = "wifi"
elif "Ethernet" in port_name or "LAN" in port_name:
current_type = "ethernet"
elif port_name.startswith("Thunderbolt"):
current_type = "thunderbolt"
else:
current_type = "unknown"
elif line.startswith("Device:"):
device = line.split(":", 1)[1].strip()
types[device] = current_type
return types
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
@@ -36,13 +68,18 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
interface_types = _get_interface_types_from_networksetup()
for iface, services in psutil.net_if_addrs().items():
for service in services:
match service.family:
case socket.AF_INET | socket.AF_INET6:
interfaces_info.append(
NetworkInterfaceInfo(name=iface, ip_address=service.address)
NetworkInterfaceInfo(
name=iface,
ip_address=service.address,
interface_type=interface_types.get(iface, "unknown"),
)
)
case _:
pass

View File

@@ -40,9 +40,31 @@ from exo.worker.download.huggingface_utils import (
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
return (
f"Model '{model_id}' requires authentication. "
f"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. "
f"Get a token at https://huggingface.co/settings/tokens"
)
elif status_code == 403:
return (
f"Access denied to '{model_id}'. "
f"Please accept the model terms at https://huggingface.co/{model_id}"
)
else:
return f"Authentication failed for '{model_id}' (HTTP {status_code})"
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -147,6 +169,8 @@ async def fetch_file_list_with_retry(
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if attempt == n_attempts - 1:
raise e
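The retry shape introduced here (auth failures escape immediately, everything else gets retried) generalizes; a minimal standalone sketch with a stand-in exception class and an illustrative backoff:

```python
import asyncio

class AuthError(Exception):
    """Stand-in for HuggingFaceAuthenticationError."""

async def fetch_with_retry(fetch, n_attempts: int = 3):
    for attempt in range(n_attempts):
        try:
            return await fetch()
        except AuthError:
            raise  # retrying a 401/403 will never succeed
        except Exception:
            if attempt == n_attempts - 1:
                raise
            await asyncio.sleep(2**attempt)  # backoff is illustrative only

# usage: asyncio.run(fetch_with_retry(some_async_callable))
```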
@@ -167,6 +191,9 @@ async def _fetch_file_list(
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
@@ -256,6 +283,9 @@ async def file_meta(
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(model_id, revision, path, redirected_location)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -279,6 +309,8 @@ async def download_file_with_retry(
return await _download_file(
model_id, revision, path, target_dir, on_progress
)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
@@ -322,6 +354,9 @@ async def _download_file(
):
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)
@@ -463,7 +498,7 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.debug(f"Downloading {shard.model_card.model_id=}")
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
@@ -476,7 +511,7 @@ async def download_shard(
allow_patterns = await resolve_allow_patterns(shard)
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.debug(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(

View File

@@ -68,7 +68,11 @@ def get_hf_home() -> Path:
async def get_hf_token() -> str | None:
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
"""Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory."""
# Check environment variable first
if token := os.environ.get("HF_TOKEN"):
return token
# Fall back to file-based token
token_path = get_hf_home() / "token"
if await aios.path.exists(token_path):
async with aiofiles.open(token_path, "r") as f:

View File

@@ -3,6 +3,8 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from loguru import logger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -19,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,7 +168,7 @@ class ResumableShardDownloader(ShardDownloader):
yield await task
# TODO: except Exception
except Exception as e:
print("Error downloading shard:", e)
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type: ignore
return output
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type: ignore
return logits
cls.__call__ = patched_call

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
return tokenizer
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
"""
tool_calls = msg_dict.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
return
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
if not isinstance(tc, dict):
continue
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if not isinstance(func, dict):
continue
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if isinstance(args, str):
with contextlib.suppress(json.JSONDecodeError):
func["arguments"] = json.loads(args)
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
tools = chat_task_data.tools
formatted_messages: list[dict[str, Any]] = []
for message in messages:
@@ -386,15 +409,19 @@ def apply_chat_template(
continue
# Null values are not valid when applying templates in tokenizer
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
tools=tools,
)
logger.info(prompt)

View File

@@ -38,6 +38,7 @@ from exo.shared.types.tasks import (
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
@@ -443,7 +444,33 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def download_with_error_handling() -> None:
try:
await self.shard_downloader.ensure_shard(task.shard_metadata)
except Exception as e:
error_message = str(e)
logger.error(
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
)
failed_status = DownloadFailed(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
error_message=error_message,
)
self.download_status[task.shard_metadata.model_card.model_id] = (
failed_status
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed_status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Failed
)
)
self._tg.start_soon(download_with_error_handling)
async def _forward_events(self) -> None:
with self.event_receiver as events:

View File

@@ -20,6 +20,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadProgress,
)
@@ -122,7 +123,8 @@ def _model_needs_download(
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
download_status[model_id],
(DownloadOngoing, DownloadCompleted, DownloadFailed),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk

View File

@@ -1,8 +1,9 @@
import base64
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,11 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
@@ -42,6 +44,8 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
ToolCallItem,
ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -154,6 +158,9 @@ def main(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -244,17 +251,53 @@ def main(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
for response in mlx_generator:
match response:
case GenerationResponse():
if device_rank == 0:
if (
device_rank == 0
and response.finish_reason == "error"
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
error_message=response.text,
model=shard_metadata.model_card.model_id,
),
)
)
elif device_rank == 0:
assert response.finish_reason not in (
"error",
"tool_calls",
"function_call",
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -263,6 +306,17 @@ def main(
),
)
)
case ToolCallResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
),
)
)
# can we make this more explicit?
except Exception as e:
@@ -270,11 +324,8 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
@@ -328,18 +379,14 @@ def main(
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message=str(e),
),
@@ -396,13 +443,8 @@ def main(
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=0,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message=str(e),
),
@@ -446,6 +488,18 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
@@ -526,7 +580,6 @@ def _send_image_chunk(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
@@ -568,6 +621,196 @@ def _process_image_response(
)
def parse_tool_calls(
responses: Generator[GenerationResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
continue
# fallthrough
yield response
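The generator above splits a token stream into plain text and buffered tool-call bodies. A simplified standalone sketch of that splitting logic, using plain tuples and `json.loads` as stand-ins for `ToolCallResponse` and the model-specific `tool_parser`:

```python
import json
from dataclasses import dataclass

@dataclass
class Tok:
    text: str

def split_stream(tokens, start, end, parser):
    buf: list[str] = []
    in_call = False
    for tok in tokens:
        if tok.text == start:            # assumption: start marker is one token
            in_call = True
            continue
        if in_call and tok.text == end:  # assumption: end marker is one token
            yield ("tool_call", parser("".join(buf).strip()))
            buf, in_call = [], False
            continue
        if in_call:
            buf.append(tok.text)
            continue
        yield ("token", tok.text)

stream = [
    Tok("Sure."),
    Tok("<tool_call>"),
    Tok('{"name": "multiply", "arguments": {"a": 2}}'),
    Tok("</tool_call>"),
]
print(list(split_stream(stream, "<tool_call>", "</tool_call>", json.loads)))
```

The real implementation additionally falls back to re-emitting the raw text when parsing fails, which this sketch omits.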
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
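Applying the two regexes from `patch_kimi_tokenizer` to the sample format given in its own comment shows what the patched parser extracts; stdlib `re` is enough for this illustration:

```python
import json
import re

sample = 'functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}'

name_match = re.search(r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", sample, re.DOTALL)
args_match = re.search(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", sample, re.DOTALL)
assert name_match and args_match  # the patch raises ValueError when either is None

name = name_match.group(1)
name = name[name.find(".") + 1 :]  # strip the "functions." prefix
print(name, json.loads(args_match.group(1)))  # multiply {'a': 2, 'b': 3}
```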
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
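And the GLM format, for comparison: the function name comes before the first `<arg_key>`, and arguments arrive as `<arg_key>`/`<arg_value>` pairs. The body below is made up but follows that shape:

```python
import re

body = (
    "get_weather<arg_key>city</arg_key><arg_value>Berlin</arg_value>"
    "<arg_key>days</arg_key><arg_value>3</arg_value>"
)

name_match = re.search(r"^(.*?)<arg_key>", body, re.DOTALL)
assert name_match is not None  # the patch turns a failed match into ValueError

pairs = re.findall(
    r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
    body,
    re.DOTALL,
)
# The real parser also deserializes non-string values against the tool schema.
print(name_match.group(1), dict(pairs))  # get_weather {'city': 'Berlin', 'days': '3'}
```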
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
return ToolCallItem(name=name, arguments=json.dumps(args))
else:
raise ValueError(f"Malformed tool call object: {obj!r}")  # pydantic's ValidationError can't be raised bare; ValueError is already handled by parse_tool_calls
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -140,6 +140,13 @@ class EventCollector:
pass
class MockTokenizer:
tool_parser = None
tool_call_start = None
tool_call_end = None
has_tool_calling = False
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -171,7 +178,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,

59
uv.lock generated
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -413,7 +413,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -458,16 +458,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -1004,8 +994,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1032,18 +1022,12 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
resolution-markers = [
"sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1056,6 +1040,14 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1086,7 +1078,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1094,16 +1086,6 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"
@@ -2227,6 +2209,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"