Compare commits

15 Commits

Author SHA1 Message Date
Evan
5de67883c0 fix: deprioritise uncertain ethernet devices
we were placing coordinators on uncertain devices (enX+) that are listed
as "USB LAN" - these could be Thunderbolt ports, breaking RDMA instances
2026-01-23 18:22:53 +00:00
Jake Hillion
6dbbe7797b downloads: add download and delete buttons to downloads UI
The downloads page showed model download progress but provided no way
for users to trigger downloads or remove completed models from disk.

Added API endpoints (POST /download/start, DELETE /download/{node_id}/{model_id})
that send StartDownload and DeleteDownload commands via the download_command_sender.
Updated the dashboard downloads page with per-model buttons: a download button
for incomplete downloads and a delete button for completed ones.

This allows users to manage downloads directly from the UI without needing
to trigger downloads through other means.
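
For illustration, a minimal Python sketch of driving the new endpoints
(the port matches the API docs elsewhere in this repo; the POST body
shape is an assumption, as the PR doesn't show the payload schema):

```python
import json
import urllib.parse
import urllib.request

BASE = "http://localhost:52415"  # exo API port used in this repo's docs

def start_download(node_id: str, model_id: str) -> None:
    # Assumed payload shape - the PR names the endpoint but not its body.
    body = json.dumps({"node_id": node_id, "model_id": model_id}).encode()
    req = urllib.request.Request(
        f"{BASE}/download/start",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    urllib.request.urlopen(req)

def delete_download(node_id: str, model_id: str) -> None:
    # Model ids may contain slashes, so quote them into one path segment.
    quoted = urllib.parse.quote(model_id, safe="")
    req = urllib.request.Request(
        f"{BASE}/download/{node_id}/{quoted}", method="DELETE"
    )
    urllib.request.urlopen(req)
```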

Test plan:
- Deployed on a 3-machine cluster. Did several downloads/deletions - all
  work and the dashboard updates fairly fluidly. It takes roughly 5
  seconds to reflect a 131GB model deletion, which isn't too bad.
2026-01-23 18:11:17 +00:00
Jake Hillion
9357503c6f downloads: refactor to run at node level
The Worker previously owned the ShardDownloader directly via dependency
injection, which prevented --no-worker nodes from downloading and made
it impossible for multiple Workers to share a single downloader instance.

Moved download functionality to a new DownloadCoordinator component at
the Node level that communicates via the DOWNLOAD_COMMANDS pub/sub topic.
Workers now send StartDownload commands instead of calling the downloader
directly, and receive progress updates through the event-sourced state.

This decouples downloads from the Worker lifecycle and enables future
features like UI-triggered downloads to specific nodes and multi-worker
download sharing.
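
For illustration, a rough sketch of the command flow, assuming simple
command dataclasses and an async pub/sub handle (every name here except
StartDownload, DeleteDownload, and DOWNLOAD_COMMANDS is invented):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class StartDownload:
    model_id: str
    target_node_id: str

@dataclass(frozen=True)
class DeleteDownload:
    model_id: str
    target_node_id: str

# Workers publish a command instead of calling the downloader directly;
# the node-level DownloadCoordinator subscribed to DOWNLOAD_COMMANDS acts
# on it, and progress flows back through the event-sourced state.
async def request_download(pubsub, model_id: str, node_id: str) -> None:
    await pubsub.publish("DOWNLOAD_COMMANDS", StartDownload(model_id, node_id))
```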

Test plan:
- Mostly tested in the next PR that adds explicit downloads/deletions to
  the dashboard.
- Started a model that isn't downloaded - it works.
2026-01-23 18:04:09 +00:00
ciaranbor
ba19940828 Fix regenerate for image models (#1263)
## Motivation

The 'regenerate' button was hardcoded to chat completion. Clicking
'regenerate' for an image request would result in an error after the
model is loaded.

## Changes

Store the request type and dispatch to the appropriate request upon
regeneration

## Why It Works

We make sure to repeat the same request type as was performed originally

## Test Plan

### Manual Testing

Checked 'regenerate' works for chat completion, image generation, image
editing
2026-01-23 16:33:01 +00:00
Jake Hillion
f255345a1a dashboard: decouple prettier-svelte from dashboard source
The prettier-svelte formatter depended on the full dashboard build
(dashboardFull), causing the devshell to rebuild whenever any dashboard
source file changed.

Created a deps-only dream2nix derivation (deps.nix) that uses a stub
source containing only package.json, package-lock.json, and minimal
files for vite to succeed. Updated prettier-svelte to use this
derivation instead of dashboardFull.

The stub source is constant unless lockfiles change, so prettier-svelte
and the devshell no longer rebuild when dashboard source files are
modified.

Test plan:
- nix flake check passed
- nix fmt successfully formatted svelte files
2026-01-23 15:16:48 +00:00
ciaranbor
a1939c89f2 Enable UI settings for image editing (#1258)
## Motivation

Image editing was missing UI controls for quality, output format, and
advanced parameters that text-to-image generation already supported.

## Changes

- Added quality, output_format, and advanced_params to image edit API
endpoints
- Extended isImageModel check to include image editing models

## Why It Works

The API now accepts and forwards these settings for image edits, and the
UI displays the appropriate controls for image editing models.

## Test Plan

### Manual Testing

Verified parameters can be set in the UI and that they propagate through
to model inference
2026-01-23 13:37:25 +00:00
ciaranbor
cb9c9ee55c Enable generating multiple images. Optionally stream partial images (#1251)
## Motivation

Support OpenAI API `n` setting

## Changes

- Users can select `n` to generate more than one image with the same
prompt
  - each image uses a different seed -> different results
- `stream` and `partial_images` settings can be overridden in the UI
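
For illustration, assuming the OpenAI-compatible images endpoint (the
exact path isn't shown in this PR), a request exercising these settings
might look like:

```python
import json
import urllib.request

# /v1/images/generations is assumed from the OpenAI API shape.
body = json.dumps({
    "model": "flux-schnell",       # illustrative model id
    "prompt": "a lighthouse at dusk",
    "n": 3,                        # three images, each seeded differently
    "stream": True,                # overridable in the UI per this PR
    "partial_images": 2,           # partial previews streamed per image
}).encode()
req = urllib.request.Request(
    "http://localhost:52415/v1/images/generations",
    data=body,
    headers={"Content-Type": "application/json"},
)
urllib.request.urlopen(req)  # response/stream handling omitted
```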
2026-01-23 11:19:58 +00:00
Alex Cheema
df240f834d Fix GLM and Kimi tool calling crashes (#1255)
## Motivation

Fixes tool calling crashes with GLM-4.7-Flash and Kimi-K2 models.

Related: #1254

Two distinct issues were causing crashes:
1. **Tool parser crashes** - The upstream GLM47 and Kimi tool parsers
call `.group()` on regex matches without checking for `None`, causing
`AttributeError` when the model outputs malformed tool calls
2. **Chat template crashes** - GLM's chat template expects
`tool_calls[].function.arguments` to be a dict, but OpenAI format
provides it as a JSON string, causing `'str object' has no attribute
'items'`

## Changes

**`src/exo/worker/runner/runner.py`:**
- Add `patch_glm_tokenizer()` - fixed version of mlx_lm's glm47 parser
with None checks
- Fix `patch_kimi_tokenizer()` - add None checks before calling
`.group()` on regex matches
- Add `ValueError` and `AttributeError` to exception handling in
`parse_tool_calls()`

**`src/exo/worker/engines/mlx/utils_mlx.py`:**
- Add `_normalize_tool_calls()` - parses
`tool_calls[].function.arguments` from JSON string to dict for templates
that expect dicts (like GLM-4.7-Flash)

## Why It Works

1. **Parser fixes**: By checking if regex matches are `None` before
calling `.group()`, we can raise a proper `ValueError` instead of
crashing with `AttributeError`

2. **Template fix**: The GLM-4.7-Flash chat template iterates over
arguments with `.items()`:
   ```jinja2
   {% set _args = tc.arguments %}{% for k, v in _args.items() %}
   ```
OpenAI format has `arguments` as a JSON string.
`_normalize_tool_calls()` parses this to a dict before passing to the
template.
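
A condensed sketch of that normalization (the real helper lives in
`utils_mlx.py`; this version is illustrative):

```python
import json
from typing import Any

def _normalize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Parse tool_calls[].function.arguments from a JSON string into a dict
    for chat templates (like GLM-4.7-Flash) that call .items() on it."""
    for message in messages:
        for call in message.get("tool_calls") or []:
            args = call.get("function", {}).get("arguments")
            if isinstance(args, str):
                try:
                    call["function"]["arguments"] = json.loads(args)
                except json.JSONDecodeError:
                    pass  # leave malformed arguments as-is
    return messages
```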

## Test Plan

### Manual Testing
- Hardware: Mac with GLM-4.7-Flash-4bit model
- Tested tool calling with GLM model - no longer crashes

### Automated Testing
- Existing tests pass (`uv run pytest`)
- Type checking passes (`uv run basedpyright`)
- Linting passes (`uv run ruff check`)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-23 01:39:59 +00:00
ciaranbor
cd125b3b8c Use icon for image editing models (#1252)
## Motivation

Visual indicator for image editing models

## Changes

Add pencil icon to edit models in model list
2026-01-22 22:37:34 +00:00
Alex Cheema
b783a21399 dashboard: add placement filter by clicking topology nodes (#1248)
## Motivation

When selecting a model for placement, users often want to see placements
that utilize specific nodes in their cluster. Currently there's no way
to filter the placement previews to focus on configurations that include
particular machines.

## Changes

- **Backend**: Added `node_ids` query parameter to the
`/placement-previews` API endpoint. When provided, the endpoint filters
the topology to only include the specified nodes before generating
placements using the new `Topology.filter_to_nodes()` method.

- **Topology class**: Added `filter_to_nodes(node_ids)` method that
creates a new topology containing only the specified nodes and edges
between them.

- **App store**: Added `previewNodeFilter` state to track selected
nodes, with methods to toggle/clear the filter. Automatically cleans up
filter when nodes are removed from the cluster and re-fetches previews
when topology changes.

- **TopologyGraph component**: Added click handlers to toggle node
filter selection, hover effects to indicate clickable nodes, and visual
styling (yellow highlight for selected, dimmed for filtered-out nodes).

- **Main page**: Added filter indicator in top-right corner of topology
showing active filter count with a clear button.

## Why It Works

The filtering happens at the backend/placement generation level rather
than just filtering the results. This ensures we see all valid placement
combinations for the selected nodes, not just a subset that happened to
be generated for the full topology.
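
Conceptually, `filter_to_nodes` is just a subgraph restriction; a minimal
sketch (attribute and helper names besides `filter_to_nodes` are guesses,
since the Topology internals aren't shown in this diff):

```python
def filter_to_nodes(self, node_ids: set[str]) -> "Topology":
    """Return a new topology with only the given nodes and the edges
    whose endpoints both survive the filter."""
    filtered = Topology()
    for node_id in self.nodes:            # assumed iterable of node ids
        if node_id in node_ids:
            filtered.add_node(node_id)    # assumed helper
    for source, sink in self.edges:       # assumed (source, sink) pairs
        if source in node_ids and sink in node_ids:
            filtered.add_edge(source, sink)
    return filtered
```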

The visual feedback uses the same rendering approach as the existing
highlight system - state is tracked in Svelte and applied during render,
so it persists across data updates without flickering.

## Test Plan

### Manual Testing
- Click a node in topology → should show yellow highlight and filter
indicator
- Click another node → indicator shows "2 nodes", previews update to
show only placements using both
- Hover over nodes → subtle yellow highlight indicates they're clickable
- Click X on filter indicator → clears filter, shows all placements
again
- Disconnect a node while it's in filter → filter auto-removes that node

### Automated Testing
- Existing tests cover the Topology class; the new `filter_to_nodes`
method follows the same patterns

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 22:12:57 +00:00
Alex Cheema
43f12f5d08 Replace LaunchDaemon with dynamic Thunderbolt Bridge loop detection (#1222)
## Motivation

The previous approach installed a LaunchDaemon plist that ran
periodically to disable Thunderbolt Bridge. This required full admin
privileges upfront and ran regardless of whether a problematic loop
existed.

This change replaces that with dynamic detection - only prompting the
user when an actual TB bridge loop with 3+ machines is detected, and
using fine-grained SCPreferences authorization instead of full admin.

## Changes

**Backend (Python):**
- Added `ThunderboltBridgeStatus` model to track bridge enabled/exists
state per node
- Added `node_thunderbolt_bridge` and `thunderbolt_bridge_cycles` fields
to State
- Added `get_thunderbolt_bridge_cycles()` method to Topology class
- **Robust TB bridge detection:**
- Finds bridge network services from `-listnetworkserviceorder` (not
`-listallhardwareports` which can miss bridges)
- Checks each bridge's member interfaces via `ifconfig` to verify it
contains Thunderbolt interfaces
- Handles varying service names (e.g., "TB Bridge", "Thunderbolt
Bridge", "Bridge (bridge0)")
  - Includes `service_name` in status for correct disable commands
  - Added warning logs for all error cases in detection
- Updated `apply.py` to handle the new event type and recompute cycles
on node timeout

**Swift App:**
- New `ThunderboltBridgeService` that monitors for cycles from cluster
state
- Shows NSAlert when a cycle with >2 machines is detected
- Uses `SCPreferencesCreateWithAuthorization` with
`system.services.systemconfiguration.network` right for targeted
permissions
- **Auto-cleanup of legacy LaunchDaemon:** On app startup, checks for
and removes old plist/scripts (non-fatal if user cancels)
- **Periodic local network checking:** Re-checks every 10s so the
warning disappears when the user grants permission
- **Fixed ClusterState model:** Updated to decode new granular state
fields (`nodeIdentities`, `nodeMemory`, `nodeSystem`,
`nodeThunderboltBridge`) with computed `nodeProfiles` property for
backwards compatibility
- **Fixed Topology model:** Updated to match actual JSON structure where
`nodes` is an array of strings (not objects) and `connections` is a
nested map (not flat array)
- Cleaned up `NetworkSetupHelper` by removing daemon installation code
(now only handles uninstall)

**Dashboard:**
- Added yellow warning badge on topology when TB bridge cycle detected
- On hover: highlights affected nodes in yellow on the topology graph
- Shows which machines are in the cycle with friendly names
- Provides copy-paste terminal command with the correct service name:
  ```
  sudo networksetup -setnetworkserviceenabled "<service-name>" off
  ```
- Warning appears in all topology views (full, welcome, and minimized
chat sidebar)
- **Debug mode:** Shows "TB:ON" or "TB:OFF" status next to each node in
the topology

## Why It Works

- Cycle detection happens on the backend where we have full topology
information
- Only cycles with 3+ machines are flagged (2-node connections are fine);
a sketch of the idea follows this list
- TB bridge detection is robust:
- Uses `-listnetworkserviceorder` to find bridges (works on all machines
tested)
- Verifies bridge membership via `ifconfig` to confirm Thunderbolt
interfaces
  - Handles different service names across machines
- The Swift app reacts to detected cycles and prompts the user once per
cycle
- The dashboard provides visual feedback and actionable instructions
- `SCPreferencesCreateWithAuthorization` provides the minimal
permissions needed to modify network service state
- Legacy LaunchDaemon is automatically cleaned up on first launch with
this version
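
For intuition, a cycle among 3+ machines means a connected group of TB
links with at least as many links as machines; a simplified sketch of the
idea (not the actual `get_thunderbolt_bridge_cycles` implementation):

```python
from collections import defaultdict

def find_bridge_cycles(edges: list[tuple[str, str]]) -> list[list[str]]:
    """Return groups of 3+ nodes whose undirected TB links form a cycle.
    A connected component contains a cycle iff it has at least as many
    edges as nodes (a tree has exactly n - 1)."""
    adj: defaultdict[str, set[str]] = defaultdict(set)
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    seen: set[str] = set()
    cycles: list[list[str]] = []
    for start in adj:
        if start in seen:
            continue
        stack, component = [start], []
        while stack:  # DFS to collect the connected component
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            component.append(node)
            stack.extend(adj[node] - seen)
        edge_count = sum(len(adj[n]) for n in component) // 2
        if len(component) >= 3 and edge_count >= len(component):
            cycles.append(sorted(component))
    return cycles
```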

## Test Plan

### Manual Testing
Here EXO detected a TB bridge cycle:

#### Dashboard:
<img width="1363" height="884" alt="Screenshot 2026-01-21 at 10 07
30 PM"
src="https://github.com/user-attachments/assets/7da9c621-0c91-42c4-898e-4952188a1f61"
/>

#### Hovering the warning:
<img width="359" height="279" alt="Screenshot 2026-01-21 at 16 30 57"
src="https://github.com/user-attachments/assets/05501dcf-3d4a-4704-9f38-257748c05a53"
/>

#### macOS app warning popup:
<img width="270" height="410" alt="Screenshot 2026-01-21 at 16 29 08"
src="https://github.com/user-attachments/assets/45714427-08c3-4fb4-9e61-144925c51adf"
/>

#### Which then asks for the user's password:
<img width="263" height="372" alt="Screenshot 2026-01-21 at 16 29 28"
src="https://github.com/user-attachments/assets/7502e591-596d-4128-8cf5-6a12674e27bc"
/>

Which, when entered, successfully disables the bridge and no longer shows
the warning on the dashboard.

#### When it fails it shows the error message:
<img width="263" height="234" alt="Screenshot 2026-01-21 at 14 45 38"
src="https://github.com/user-attachments/assets/2d10b3d5-69d7-46ea-b631-d52d8651ab41"
/>

### Automated Testing
- Type checker: 0 errors (`uv run basedpyright`)
- Linter: All checks passed (`uv run ruff check`)
- Tests: 118 passed (`uv run pytest`)
- Dashboard: Builds successfully (`npm run build`)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 21:53:05 +00:00
ciaranbor
8027d7933f Ciaran/hf token (#1250)
## Motivation

black-forest-labs models require hf auth and signup to download. We
don't handle this gracefully.
https://github.com/exo-explore/exo/issues/1242

## Changes

- Handle auth errors
- Surface error to UI and suggest resolution
- Support using the HF_TOKEN env variable for auth
- Hide image functionality behind `EXO_ENABLE_IMAGE_MODELS=true` for now

## Why It Works

Users are presented with actionable feedback when the issue occurs

## Test Plan

### Manual Testing

Confirmed that loading a black-forest-labs model surfaces the issue in
the UI.
Confirmed both `hf auth login` and setting `HF_TOKEN` resolve the issue
2026-01-22 20:39:53 +00:00
Evan
ac6efa747b add kimi tool parsing
this patches the kimi tokenizer to add tool calling - it can be reverted
once upstream support is added for kimi-k2
2026-01-22 11:49:25 +00:00
Evan
2e3c33db6d implement mlx-lm tool calling
splits up the runner's generation chunks into tool calls, tokens and
errors, and writes tool call chunks when the upstream parser detects
them.
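
For illustration, the chunk split might be typed along these lines (names
are invented; the actual runner types aren't shown in this commit):

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class TokenChunk:
    text: str

@dataclass(frozen=True)
class ToolCallChunk:
    name: str
    arguments: str  # raw JSON as emitted by the upstream parser

@dataclass(frozen=True)
class ErrorChunk:
    message: str

GenerationChunk = Union[TokenChunk, ToolCallChunk, ErrorChunk]
```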
2026-01-22 11:49:25 +00:00
rltakashige
fc8e6ad06b Reduce download log spam (#1249)
2026-01-22 11:28:36 +00:00
68 changed files with 3998 additions and 2566 deletions

View File

@@ -276,23 +276,24 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]
unprocessed_prompts: List[Any]
def __init__(
self,
model: nn.Module,
model,
max_tokens: int = ...,
stop_tokens: Optional[set[int]] = ...,
stop_tokens: Optional[set] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
def batch_generate(
model,

View File

@@ -39,18 +39,12 @@ class StreamingDetokenizer:
"""
__slots__ = ...
tokens: list[int]
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@property
def text(self) -> str:
"""The full text decoded so far."""
...
@property
def last_segment(self) -> str:
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
...
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -114,7 +108,6 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int

View File

@@ -116,45 +116,6 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Model Storage
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
## Creating Model Instances via API
When testing with the API, you must first create a model instance before sending chat completions:
```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
# 3. Wait for the runner to become ready (check logs for "runner ready")
# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```
## Logs
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
### Distributed Testing
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

View File

@@ -14,6 +14,7 @@ struct ContentView: View {
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@@ -24,6 +25,8 @@ struct ContentView: View {
@State private var bugReportMessage: String?
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""
@State private var pendingHFToken: String = ""
@State private var pendingEnableImageModels = false
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -303,6 +306,49 @@ struct ContentView: View {
.disabled(pendingNamespace == controller.customNamespace)
}
}
VStack(alignment: .leading, spacing: 4) {
Text("HuggingFace Token")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
SecureField("optional", text: $pendingHFToken)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingHFToken = controller.hfToken
}
Button("Save & Restart") {
controller.hfToken = pendingHFToken
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingHFToken == controller.hfToken)
}
}
Divider()
HStack {
Toggle(
"Enable Image Models (experimental)", isOn: $pendingEnableImageModels
)
.toggleStyle(.switch)
.font(.caption2)
.onAppear {
pendingEnableImageModels = controller.enableImageModels
}
Spacer()
Button("Save & Restart") {
controller.enableImageModels = pendingEnableImageModels
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingEnableImageModels == controller.enableImageModels)
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
@@ -423,6 +469,44 @@ struct ContentView: View {
}
}
/// Shows TB bridge status for all nodes from exo cluster state
private var clusterThunderboltBridgeView: some View {
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
let localNodeId = stateService.localNodeId
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
return VStack(alignment: .leading, spacing: 1) {
if bridgeStatuses.isEmpty {
Text("Cluster TB Bridge: No data")
.font(.caption2)
.foregroundColor(.secondary)
} else {
Text("Cluster TB Bridge Status:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
if let status = bridgeStatuses[nodeId] {
let nodeName =
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
let isLocal = nodeId == localNodeId
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
let statusText =
!status.exists
? "N/A"
: (status.enabled ? "Enabled" : "Disabled")
let color: Color =
!status.exists
? .secondary
: (status.enabled ? .red : .green)
Text("\(prefix) \(statusText)")
.font(.caption2)
.foregroundColor(color)
}
}
}
}
}
private var interfaceIpList: some View {
let statuses = networkStatusService.status.interfaceStatuses
return VStack(alignment: .leading, spacing: 1) {
@@ -465,6 +549,7 @@ struct ContentView: View {
Text(thunderboltStatusText)
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
clusterThunderboltBridgeView
interfaceIpList
rdmaStatusView
sendBugReportButton

View File

@@ -21,6 +21,7 @@ struct EXOApp: App {
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -41,10 +42,13 @@ struct EXOApp: App {
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -58,6 +62,7 @@ struct EXOApp: App {
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
.environmentObject(thunderboltBridgeService)
} label: {
menuBarIcon
}
@@ -130,6 +135,37 @@ struct EXOApp: App {
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
// Dispatch async to ensure app is ready before showing alert
DispatchQueue.main.async {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to configure local network discovery on your device. This requires granting permission once."
alert.alertStyle = .informational
alert.addButton(withTitle: "Continue")
alert.addButton(withTitle: "Later")
let response = alert.runModal()
guard response == .alertFirstButtonReturn else {
Logger().info("User deferred legacy network setup cleanup")
return
}
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
}
/// Helper for managing EXO's launch-at-login registration

View File

@@ -3,6 +3,8 @@ import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
private let hfTokenKey = "EXOHFToken"
private let enableImageModelsKey = "EXOEnableImageModels"
@MainActor
final class ExoProcessController: ObservableObject {
@@ -37,6 +39,22 @@ final class ExoProcessController: ObservableObject {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
@Published var hfToken: String = {
return UserDefaults.standard.string(forKey: hfTokenKey) ?? ""
}()
{
didSet {
UserDefaults.standard.set(hfToken, forKey: hfTokenKey)
}
}
@Published var enableImageModels: Bool = {
return UserDefaults.standard.bool(forKey: enableImageModelsKey)
}()
{
didSet {
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -191,6 +209,12 @@ final class ExoProcessController: ObservableObject {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
if !hfToken.isEmpty {
environment["HF_TOKEN"] = hfToken
}
if enableImageModels {
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
}
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {

View File

@@ -5,17 +5,43 @@ import Foundation
struct ClusterState: Decodable {
let instances: [String: ClusterInstance]
let runners: [String: RunnerStatusSummary]
let nodeProfiles: [String: NodeProfile]
let tasks: [String: ClusterTask]
let topology: Topology?
let downloads: [String: [NodeDownloadStatus]]
let thunderboltBridgeCycles: [[String]]
// Granular node state (split from the old nodeProfiles)
let nodeIdentities: [String: NodeIdentity]
let nodeMemory: [String: MemoryInfo]
let nodeSystem: [String: SystemInfo]
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
/// Computed property for backwards compatibility - merges granular state into NodeProfile
var nodeProfiles: [String: NodeProfile] {
var profiles: [String: NodeProfile] = [:]
let allNodeIds = Set(nodeIdentities.keys)
.union(nodeMemory.keys)
.union(nodeSystem.keys)
for nodeId in allNodeIds {
let identity = nodeIdentities[nodeId]
let memory = nodeMemory[nodeId]
let system = nodeSystem[nodeId]
profiles[nodeId] = NodeProfile(
modelId: identity?.modelId,
chipId: identity?.chipId,
friendlyName: identity?.friendlyName,
memory: memory,
system: system
)
}
return profiles
}
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
self.thunderboltBridgeCycles =
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
// Granular node state
self.nodeIdentities =
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
?? [:]
self.nodeMemory =
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
self.nodeSystem =
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
self.nodeThunderboltBridge =
try container.decodeIfPresent(
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
) ?? [:]
}
private enum CodingKeys: String, CodingKey {
case instances
case runners
case nodeProfiles
case topology
case tasks
case downloads
case thunderboltBridgeCycles
case nodeIdentities
case nodeMemory
case nodeSystem
case nodeThunderboltBridge
}
}
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
let system: SystemInfo?
}
struct NodeIdentity: Decodable {
let modelId: String?
let chipId: String?
let friendlyName: String?
}
struct ThunderboltBridgeStatus: Decodable {
let enabled: Bool
let exists: Bool
let serviceName: String?
}
struct MemoryInfo: Decodable {
let ramTotal: MemoryValue?
let ramAvailable: MemoryValue?
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
}
struct Topology: Decodable {
let nodes: [TopologyNode]
let connections: [TopologyConnection]?
/// Node IDs in the topology
let nodes: [String]
/// Flattened list of connections (source -> sink pairs)
let connections: [TopologyConnection]
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
// Connections come as nested map: { source: { sink: [edges] } }
// We flatten to array of (source, sink) pairs
var flatConnections: [TopologyConnection] = []
if let nested = try container.decodeIfPresent(
[String: [String: [AnyCodable]]].self, forKey: .connections
) {
for (source, sinks) in nested {
for sink in sinks.keys {
flatConnections.append(
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
}
}
}
self.connections = flatConnections
}
private enum CodingKeys: String, CodingKey {
case nodes
case connections
}
}
struct TopologyNode: Decodable {
let nodeId: String
let nodeProfile: NodeProfile
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
private struct AnyCodable: Decodable {
init(from decoder: Decoder) throws {
// Just consume the value without storing it
_ = try? decoder.singleValueContainer().decode(Bool.self)
_ = try? decoder.singleValueContainer().decode(Int.self)
_ = try? decoder.singleValueContainer().decode(Double.self)
_ = try? decoder.singleValueContainer().decode(String.self)
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
}
}
struct TopologyConnection: Decodable {
struct TopologyConnection {
let localNodeId: String
let sendBackNodeId: String
}

View File

@@ -55,12 +55,16 @@ struct BugReportService {
let stateData = try await stateResult
let eventsData = try await eventsResult
// Extract cluster TB bridge status from exo state
let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)
let reportJSON = makeReportJson(
timestamp: timestamp,
hostName: hostName,
ifconfig: ifconfigText,
debugInfo: debugInfo,
isManual: isManual
isManual: isManual,
clusterTbBridgeStatus: clusterTbBridgeStatus
)
let uploads: [(path: String, data: Data?)] = [
@@ -178,18 +182,19 @@ struct BugReportService {
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
return false
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
// No bridge containing Thunderbolt interfaces exists
return nil
}
if output.contains("disabled") {
return true
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
else {
return nil
}
return nil
// Return true if disabled, false if enabled
return !isEnabled
}
private func readInterfaces() -> [DebugInfo.InterfaceStatus] {
@@ -268,11 +273,12 @@ struct BugReportService {
hostName: String,
ifconfig: String,
debugInfo: DebugInfo,
isManual: Bool
isManual: Bool,
clusterTbBridgeStatus: [[String: Any]]?
) -> Data? {
let system = readSystemMetadata()
let exo = readExoMetadata()
let payload: [String: Any] = [
var payload: [String: Any] = [
"timestamp": timestamp,
"host": hostName,
"ifconfig": ifconfig,
@@ -282,9 +288,38 @@ struct BugReportService {
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated",
]
if let tbStatus = clusterTbBridgeStatus {
payload["cluster_thunderbolt_bridge"] = tbStatus
}
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
/// Extracts cluster-wide Thunderbolt Bridge status from exo state JSON
private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {
guard let data = stateData,
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
let nodeThunderboltBridge = json["node_thunderbolt_bridge"] as? [String: [String: Any]]
else {
return nil
}
var result: [[String: Any]] = []
for (nodeId, status) in nodeThunderboltBridge {
var entry: [String: Any] = ["node_id": nodeId]
if let enabled = status["enabled"] as? Bool {
entry["enabled"] = enabled
}
if let exists = status["exists"] as? Bool {
entry["exists"] = exists
}
if let serviceName = status["service_name"] as? String {
entry["service_name"] = serviceName
}
result.append(entry)
}
return result.isEmpty ? nil : result
}
private func readSystemMetadata() -> [String: Any] {
let hostname = safeRunCommand(["/bin/hostname"])
let computerName = safeRunCommand(["/usr/sbin/scutil", "--get", "ComputerName"])

View File

@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
private var periodicTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
/// Checks if local network access is working (one-time check).
func check() {
performCheck()
}
/// Starts periodic checking of local network access.
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
func startPeriodicChecking(interval: TimeInterval = 10) {
stopPeriodicChecking()
// Do an immediate check first
performCheck()
// Then schedule periodic checks
periodicTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
guard !Task.isCancelled else { break }
self?.performCheck()
}
}
}
/// Stops periodic checking.
func stopPeriodicChecking() {
periodicTask?.cancel()
periodicTask = nil
}
private func performCheck() {
checkTask?.cancel()
status = .checking
// Only show "checking" status on first check to avoid UI flicker
if status == .unknown {
status = .checking
}
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
// Only log on state changes or first check to reduce noise
if isFirstCheck || result != self.status {
Self.logger.info("Local network check: \(result.displayText)")
}
}
}
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
}
func stop() {
stopPeriodicChecking()
checkTask?.cancel()
checkTask = nil
connection?.cancel()

View File

@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
// Legacy script path from older versions
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
}
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
return scriptExists || legacyScriptExists || plistExists
}
private static func makeUninstallScript() -> String {
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
# Remove the script (current and legacy paths) and parent directory if empty
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
@@ -98,99 +63,42 @@ enum NetworkSetupHelper {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable Thunderbolt Bridge if it exists
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
# We find it dynamically by looking for bridges containing Thunderbolt interfaces
find_and_enable_thunderbolt_bridge() {
# Get Thunderbolt interface devices from hardware ports
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
if echo "$members" | grep -qx "$tb_dev"; then
# Find the service name for this bridge device
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
fi
fi
done
done
}
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() async throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script

View File

@@ -153,22 +153,18 @@ private struct NetworkStatusFetcher {
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
let lower = result.output.lowercased() + result.error.lowercased()
if lower.contains("not a recognized network service") {
return .deleted
}
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
// No bridge containing Thunderbolt interfaces exists
return .deleted
}
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
else {
return nil
}
let output = result.output.lowercased()
if output.contains("enabled") {
return .enabled
}
if output.contains("disabled") {
return .disabled
}
return nil
return isEnabled ? .enabled : .disabled
}
private func readBridgeInactive() -> Bool? {

View File

@@ -0,0 +1,194 @@
import Foundation
import os.log
/// Utility for dynamically detecting Thunderbolt Bridge network services.
/// This mirrors the Python logic in info_gatherer.py - we never assume the service
/// is named "Thunderbolt Bridge", instead we find bridges containing Thunderbolt interfaces.
enum ThunderboltBridgeDetector {
private static let logger = Logger(
subsystem: "io.exo.EXO", category: "ThunderboltBridgeDetector")
struct CommandResult {
let exitCode: Int32
let output: String
let error: String
}
/// Find the network service name of a bridge containing Thunderbolt interfaces.
/// Returns nil if no such bridge exists.
static func findThunderboltBridgeServiceName() -> String? {
// 1. Get all Thunderbolt interface devices (e.g., en2, en3)
guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {
logger.debug("No Thunderbolt devices found")
return nil
}
logger.debug("Found Thunderbolt devices: \(thunderboltDevices.joined(separator: ", "))")
// 2. Get bridge services from network service order
guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {
logger.debug("No bridge services found")
return nil
}
logger.debug("Found bridge services: \(bridgeServices.keys.joined(separator: ", "))")
// 3. Find a bridge that contains Thunderbolt interfaces
for (bridgeDevice, serviceName) in bridgeServices {
let members = getBridgeMembers(bridgeDevice: bridgeDevice)
logger.debug(
"Bridge \(bridgeDevice) (\(serviceName)) has members: \(members.joined(separator: ", "))"
)
// Check if any Thunderbolt device is a member of this bridge
if !members.isDisjoint(with: thunderboltDevices) {
logger.info(
"Found Thunderbolt Bridge service: '\(serviceName)' (device: \(bridgeDevice))")
return serviceName
}
}
logger.debug("No bridge found containing Thunderbolt interfaces")
return nil
}
/// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
private static func getThunderboltDevices() -> Set<String>? {
let result = runCommand(["networksetup", "-listallhardwareports"])
guard result.exitCode == 0 else {
logger.warning("networksetup -listallhardwareports failed: \(result.error)")
return nil
}
var thunderboltDevices: Set<String> = []
var currentPort: String?
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("Hardware Port:") {
currentPort = String(trimmed.dropFirst("Hardware Port:".count)).trimmingCharacters(
in: .whitespaces)
} else if trimmed.hasPrefix("Device:"), let port = currentPort {
let device = String(trimmed.dropFirst("Device:".count)).trimmingCharacters(
in: .whitespaces)
if port.lowercased().contains("thunderbolt") {
thunderboltDevices.insert(device)
}
currentPort = nil
}
}
return thunderboltDevices
}
/// Get mapping of bridge device -> service name from network service order.
private static func getBridgeServices() -> [String: String]? {
let result = runCommand(["networksetup", "-listnetworkserviceorder"])
guard result.exitCode == 0 else {
logger.warning("networksetup -listnetworkserviceorder failed: \(result.error)")
return nil
}
// Parse service order to find bridge devices and their service names
// Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
var bridgeServices: [String: String] = [:]
var currentService: String?
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
// Match "(N) Service Name" or "(*) Service Name" (disabled)
// but NOT "(Hardware Port: ...)" lines
if trimmed.hasPrefix("("), trimmed.contains(")"),
!trimmed.hasPrefix("(Hardware Port:")
{
if let parenEnd = trimmed.firstIndex(of: ")") {
let afterParen = trimmed.index(after: parenEnd)
if afterParen < trimmed.endIndex {
currentService =
String(trimmed[afterParen...])
.trimmingCharacters(in: .whitespaces)
}
}
}
// Match "(Hardware Port: ..., Device: bridgeX)"
else if let service = currentService, trimmed.contains("Device: bridge") {
// Extract device name from "..., Device: bridge0)"
if let deviceRange = trimmed.range(of: "Device: ") {
let afterDevice = trimmed[deviceRange.upperBound...]
if let parenIndex = afterDevice.firstIndex(of: ")") {
let device = String(afterDevice[..<parenIndex])
bridgeServices[device] = service
}
}
}
}
return bridgeServices
}
/// Get member interfaces of a bridge device via ifconfig.
private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {
let result = runCommand(["ifconfig", bridgeDevice])
guard result.exitCode == 0 else {
logger.debug("ifconfig \(bridgeDevice) failed")
return []
}
var members: Set<String> = []
for line in result.output.components(separatedBy: .newlines) {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("member:") {
let parts = trimmed.split(separator: " ")
if parts.count > 1 {
members.insert(String(parts[1]))
}
}
}
return members
}
/// Check if a network service is enabled.
static func isServiceEnabled(serviceName: String) -> Bool? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", serviceName])
guard result.exitCode == 0 else {
logger.warning("Failed to check if '\(serviceName)' is enabled: \(result.error)")
return nil
}
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private static func runCommand(_ arguments: [String]) -> CommandResult {
let process = Process()
process.launchPath = "/usr/bin/env"
process.arguments = arguments
let stdout = Pipe()
let stderr = Pipe()
process.standardOutput = stdout
process.standardError = stderr
do {
try process.run()
} catch {
return CommandResult(exitCode: -1, output: "", error: error.localizedDescription)
}
process.waitUntilExit()
let outputData = stdout.fileHandleForReading.readDataToEndOfFile()
let errorData = stderr.fileHandleForReading.readDataToEndOfFile()
return CommandResult(
exitCode: process.terminationStatus,
output: String(decoding: outputData, as: UTF8.self),
error: String(decoding: errorData, as: UTF8.self)
)
}
}

View File

@@ -0,0 +1,258 @@
import AppKit
import Combine
import Foundation
import Security
import SystemConfiguration
import os.log
@MainActor
final class ThunderboltBridgeService: ObservableObject {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
@Published private(set) var detectedCycle: [String]?
@Published private(set) var hasPromptedForCurrentCycle = false
@Published private(set) var lastError: String?
private weak var clusterStateService: ClusterStateService?
private var cancellables = Set<AnyCancellable>()
private var previousCycleSignature: String?
init(clusterStateService: ClusterStateService) {
self.clusterStateService = clusterStateService
setupObserver()
}
private func setupObserver() {
guard let service = clusterStateService else { return }
service.$latestSnapshot
.compactMap { $0 }
.sink { [weak self] snapshot in
self?.checkForCycles(snapshot: snapshot)
}
.store(in: &cancellables)
}
private func checkForCycles(snapshot: ClusterState) {
let cycles = snapshot.thunderboltBridgeCycles
// Only consider cycles with more than 2 nodes
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
// No problematic cycles detected, reset state
if detectedCycle != nil {
detectedCycle = nil
previousCycleSignature = nil
hasPromptedForCurrentCycle = false
}
return
}
// Create a signature for this cycle to detect if it changed
let cycleSignature = firstCycle.sorted().joined(separator: ",")
// If this is a new/different cycle, reset the prompt state
if cycleSignature != previousCycleSignature {
previousCycleSignature = cycleSignature
hasPromptedForCurrentCycle = false
}
detectedCycle = firstCycle
// Only prompt once per cycle
if !hasPromptedForCurrentCycle {
showDisableBridgePrompt(nodeIds: firstCycle)
}
}
private func showDisableBridgePrompt(nodeIds: [String]) {
hasPromptedForCurrentCycle = true
// Get friendly names for the nodes if available
let nodeNames = nodeIds.map { nodeId -> String in
if let snapshot = clusterStateService?.latestSnapshot,
let profile = snapshot.nodeProfiles[nodeId],
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
{
return friendlyName
}
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
}
let machineNames = nodeNames.joined(separator: ", ")
let alert = NSAlert()
alert.messageText = "Thunderbolt Bridge Loop Detected"
alert.informativeText = """
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Disable Bridge")
alert.addButton(withTitle: "Not Now")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
Task {
await disableThunderboltBridge()
}
}
}
func disableThunderboltBridge() async {
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
lastError = nil
do {
try await disableThunderboltBridgeWithSCPreferences()
Self.logger.info("Successfully disabled Thunderbolt Bridge")
} catch {
Self.logger.error(
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
)
lastError = error.localizedDescription
showErrorAlert(message: error.localizedDescription)
}
}
private func disableThunderboltBridgeWithSCPreferences() async throws {
// 1. Create authorization reference
var authRef: AuthorizationRef?
var status = AuthorizationCreate(nil, nil, [], &authRef)
guard status == errAuthorizationSuccess, let authRef = authRef else {
throw ThunderboltBridgeError.authorizationFailed
}
defer { AuthorizationFree(authRef, [.destroyRights]) }
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled
}
throw ThunderboltBridgeError.authorizationDenied
}
// 3. Create SCPreferences with authorization
guard
let prefs = SCPreferencesCreateWithAuthorization(
kCFAllocatorDefault,
"EXO" as CFString,
nil,
authRef
)
else {
throw ThunderboltBridgeError.preferencesCreationFailed
}
// 4. Lock, modify, commit
guard SCPreferencesLock(prefs, true) else {
throw ThunderboltBridgeError.lockFailed
}
defer {
SCPreferencesUnlock(prefs)
}
// 5. Find the Thunderbolt Bridge service dynamically (don't assume the name)
guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()
else {
throw ThunderboltBridgeError.serviceNotFound
}
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
throw ThunderboltBridgeError.servicesNotFound
}
var found = false
for service in allServices {
if let name = SCNetworkServiceGetName(service) as String?,
name == targetServiceName
{
guard SCNetworkServiceSetEnabled(service, false) else {
throw ThunderboltBridgeError.disableFailed
}
found = true
Self.logger.info(
"Found and disabled Thunderbolt Bridge service: '\(targetServiceName)'")
break
}
}
if !found {
throw ThunderboltBridgeError.serviceNotFound
}
// 6. Commit and apply
guard SCPreferencesCommitChanges(prefs) else {
throw ThunderboltBridgeError.commitFailed
}
guard SCPreferencesApplyChanges(prefs) else {
throw ThunderboltBridgeError.applyFailed
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Failed to Disable Thunderbolt Bridge"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
}
enum ThunderboltBridgeError: LocalizedError {
case authorizationFailed
case authorizationCanceled
case authorizationDenied
case preferencesCreationFailed
case lockFailed
case servicesNotFound
case serviceNotFound
case disableFailed
case commitFailed
case applyFailed
var errorDescription: String? {
switch self {
case .authorizationFailed:
return "Failed to create authorization"
case .authorizationCanceled:
return "Authorization was canceled by user"
case .authorizationDenied:
return "Authorization was denied"
case .preferencesCreationFailed:
return "Failed to access network preferences"
case .lockFailed:
return "Failed to lock network preferences for modification"
case .servicesNotFound:
return "Could not retrieve network services"
case .serviceNotFound:
return "Thunderbolt Bridge service not found"
case .disableFailed:
return "Failed to disable Thunderbolt Bridge service"
case .commitFailed:
return "Failed to save network configuration changes"
case .applyFailed:
return "Failed to apply network configuration changes"
}
}
}

View File

@@ -86,7 +86,7 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let topologyNodeIds = Set(topology?.nodes ?? [])
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
@@ -95,8 +95,8 @@ extension ClusterState {
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
var orderedNodes: [NodeViewModel] = []
if let topologyNodes = topology?.nodes {
for topoNode in topologyNodes {
if let viewModel = nodesById[topoNode.nodeId] {
for nodeId in topologyNodes {
if let viewModel = nodesById[nodeId] {
orderedNodes.append(viewModel)
}
}
@@ -116,7 +116,7 @@ extension ClusterState {
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
topology?.connections.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -3,6 +3,45 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and the minimal files needed for the build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so the vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -42,11 +81,12 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -89,7 +89,10 @@
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsTextToImage(currentModel);
return (
modelSupportsTextToImage(currentModel) ||
modelSupportsImageEditing(currentModel)
);
});
const isEditOnlyWithoutImage = $derived(
@@ -646,6 +649,23 @@
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg

View File

@@ -110,6 +110,36 @@
setImageGenerationParams({ negativePrompt: value || null });
}
function handleNumImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ numImages: 1 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 1) {
setImageGenerationParams({ numImages: num });
}
}
}
function handleStreamChange(enabled: boolean) {
setImageGenerationParams({ stream: enabled });
}
function handlePartialImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ partialImages: 0 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 0) {
setImageGenerationParams({ partialImages: num });
}
}
}
function clearSteps() {
setImageGenerationParams({ numInferenceSteps: null });
}
@@ -134,90 +164,92 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{/if}
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -325,6 +357,59 @@
</div>
</div>
<!-- Number of Images (not in edit mode) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>IMAGES:</span
>
<input
type="number"
min="1"
value={params.numImages}
oninput={handleNumImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Stream toggle -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>STREAM:</span
>
<button
type="button"
onclick={() => handleStreamChange(!params.stream)}
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
? 'bg-exo-yellow'
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
>
<div
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
? 'right-0.5 bg-exo-black'
: 'left-0.5 bg-exo-light-gray'}"
></div>
</button>
</div>
<!-- Partial Images (only when streaming) -->
{#if params.stream}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>PARTIALS:</span
>
<input
type="number"
min="0"
value={params.partialImages}
oninput={handlePartialImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">

View File

@@ -5,22 +5,32 @@
topologyData,
isTopologyMinimized,
debugMode,
nodeThunderboltBridge,
type NodeInfo,
} from "$lib/stores/app.svelte";
interface Props {
class?: string;
highlightedNodes?: Set<string>;
filteredNodes?: Set<string>;
onNodeClick?: (nodeId: string) => void;
}
let { class: className = "", highlightedNodes = new Set() }: Props = $props();
let {
class: className = "",
highlightedNodes = new Set(),
filteredNodes = new Set(),
onNodeClick,
}: Props = $props();
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
let hoveredNodeId = $state<string | null>(null);
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -522,10 +532,72 @@
}
}
let iconBaseWidth = nodeRadius * 1.2;
let iconBaseHeight = nodeRadius * 1.0;
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
const modelLower = modelId.toLowerCase();
// Check node states for styling
const isHighlighted = highlightedNodes.has(nodeInfo.id);
const isInFilter =
filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);
const isFilteredOut =
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
const wireColor = isInFilter
? "rgba(255,215,0,1)" // Bright yellow for filter selection
: isHovered
? "rgba(255,215,0,0.7)" // Subtle yellow for hover
: isHighlighted
? "rgba(255,215,0,0.9)" // Yellow for instance highlight
: isFilteredOut
? "rgba(140,140,140,0.6)" // Grey for filtered out
: "rgba(179,179,179,0.8)"; // Default
const wireColorBright = "rgba(255,255,255,0.9)";
const fillColor = isInFilter
? "rgba(255,215,0,0.25)"
: isHovered
? "rgba(255,215,0,0.12)"
: isHighlighted
? "rgba(255,215,0,0.15)"
: "rgba(255,215,0,0.08)";
const strokeWidth = isInFilter
? 3
: isHovered
? 2
: isHighlighted
? 2.5
: 1.5;
const screenFill = "rgba(0,20,40,0.9)";
const glowColor = "rgba(255,215,0,0.3)";
const nodeG = nodesGroup
.append("g")
.attr("class", "graph-node")
.style("cursor", "pointer");
.style("cursor", onNodeClick ? "pointer" : "default")
.style("opacity", isFilteredOut ? 0.5 : 1);
// Add click and hover handlers - hover just updates state, styling is applied during render
nodeG
.on("click", (event: MouseEvent) => {
if (onNodeClick) {
event.stopPropagation();
onNodeClick(nodeInfo.id);
}
})
.on("mouseenter", () => {
if (onNodeClick) {
hoveredNodeId = nodeInfo.id;
}
})
.on("mouseleave", () => {
if (hoveredNodeId === nodeInfo.id) {
hoveredNodeId = null;
}
});
// Add tooltip
nodeG
@@ -534,27 +606,6 @@
`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,
);
let iconBaseWidth = nodeRadius * 1.2;
let iconBaseHeight = nodeRadius * 1.0;
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
const modelLower = modelId.toLowerCase();
// Check if this node should be highlighted (from hovered instance)
const isHighlighted = highlightedNodes.has(nodeInfo.id);
// Holographic wireframe colors - yellow border when highlighted
const wireColor = isHighlighted
? "rgba(255,215,0,0.9)"
: "rgba(179,179,179,0.8)";
const wireColorBright = "rgba(255,255,255,0.9)";
const fillColor = isHighlighted
? "rgba(255,215,0,0.15)"
: "rgba(255,215,0,0.08)";
const strokeWidth = isHighlighted ? 2.5 : 1.5;
const screenFill = "rgba(0,20,40,0.9)";
const glowColor = "rgba(255,215,0,0.3)";
if (modelLower === "mac studio") {
// Mac Studio - classic cube with memory fill
iconBaseWidth = nodeRadius * 1.25;
@@ -579,6 +630,7 @@
// Main body (uniform color)
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", x)
.attr("y", y)
.attr("width", iconBaseWidth)
@@ -661,6 +713,7 @@
// Main body (uniform color)
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", x)
.attr("y", y)
.attr("width", iconBaseWidth)
@@ -738,6 +791,7 @@
// Screen outer frame
nodeG
.append("rect")
.attr("class", "node-outline")
.attr("x", screenX)
.attr("y", y)
.attr("width", screenWidth)
@@ -846,6 +900,7 @@
// Main shape
nodeG
.append("polygon")
.attr("class", "node-outline")
.attr("points", hexPoints)
.attr("fill", fillColor)
.attr("stroke", wireColor)
@@ -1064,11 +1119,41 @@
.attr("fill", "rgba(179,179,179,0.7)")
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}
// Debug mode: Show TB bridge status
if (debugEnabled) {
const tbStatus = tbBridgeData[nodeInfo.id];
if (tbStatus) {
const tbY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const tbFontSize = showFullLabels ? 9 : 7;
const tbColor = tbStatus.enabled
? "rgba(234,179,8,0.9)"
: "rgba(100,100,100,0.7)";
const tbText = tbStatus.enabled ? "TB:ON" : "TB:OFF";
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", tbY)
.attr("text-anchor", "middle")
.attr("fill", tbColor)
.attr("font-size", tbFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(tbText);
}
}
});
}
$effect(() => {
if (data) {
// Track all reactive dependencies that affect rendering
const _data = data;
const _hoveredNodeId = hoveredNodeId;
const _filteredNodes = filteredNodes;
const _highlightedNodes = highlightedNodes;
if (_data) {
renderGraph();
}
});
@@ -1091,12 +1176,8 @@
<style>
:global(.graph-node) {
transition:
transform 0.2s ease,
opacity 0.2s ease;
}
:global(.graph-node:hover) {
filter: brightness(1.1);
/* Only transition opacity for filtered-out nodes; hover stroke changes apply instantly */
transition: opacity 0.2s ease;
}
:global(.graph-link) {
stroke: var(--exo-light-gray, #b3b3b3);

View File

@@ -190,6 +190,13 @@ interface RawStateResponse {
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
// Thunderbolt bridge status per node
nodeThunderboltBridge?: Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
}
export interface MessageAttachment {
@@ -209,6 +216,8 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
}
export interface Conversation {
@@ -231,6 +240,10 @@ export interface ImageGenerationParams {
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
// Streaming params
stream: boolean;
partialImages: number;
// Advanced params
seed: number | null;
numInferenceSteps: number | null;
@@ -250,6 +263,9 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,
stream: true,
partialImages: 3,
seed: null,
numInferenceSteps: null,
guidance: null,
@@ -419,7 +435,15 @@ class AppStore {
placementPreviews = $state<PlacementPreview[]>([]);
selectedPreviewModelId = $state<string | null>(null);
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderboltBridge = $state<
Record<
string,
{ enabled: boolean; exists: boolean; serviceName?: string | null }
>
>({});
// UI state
isTopologyMinimized = $state(false);
@@ -439,6 +463,7 @@ class AppStore {
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
constructor() {
if (browser) {
@@ -997,6 +1022,8 @@ class AppStore {
nodeSystem: data.nodeSystem,
nodeNetwork: data.nodeNetwork,
});
// Handle topology changes for preview filter
this.handleTopologyChange();
}
if (data.instances) {
this.instances = data.instances;
@@ -1008,6 +1035,10 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};
this.lastUpdate = Date.now();
} catch (error) {
console.error("Error fetching state:", error);
@@ -1023,9 +1054,14 @@ class AppStore {
this.selectedPreviewModelId = modelId;
try {
const response = await fetch(
`/instance/previews?model_id=${encodeURIComponent(modelId)}`,
);
let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;
// Add node filter if active
if (this.previewNodeFilter.size > 0) {
for (const nodeId of this.previewNodeFilter) {
url += `&node_ids=${encodeURIComponent(nodeId)}`;
}
}
const response = await fetch(url);
if (!response.ok) {
throw new Error(
`Failed to fetch placement previews: ${response.status}`,
@@ -1075,6 +1111,71 @@ class AppStore {
}
}
/**
* Toggle a node in the preview filter and re-fetch placements
*/
togglePreviewNodeFilter(nodeId: string) {
const next = new Set(this.previewNodeFilter);
if (next.has(nodeId)) {
next.delete(nodeId);
} else {
next.add(nodeId);
}
this.previewNodeFilter = next;
// Re-fetch with new filter if we have a selected model
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
/**
* Clear the preview node filter and re-fetch placements
*/
clearPreviewNodeFilter() {
this.previewNodeFilter = new Set();
// Re-fetch with no filter if we have a selected model
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
/**
* Handle topology changes - clean up filter and re-fetch if needed
*/
private handleTopologyChange() {
if (!this.topologyData) return;
const currentNodeIds = new Set(Object.keys(this.topologyData.nodes));
// Check if nodes have changed
const nodesAdded = [...currentNodeIds].some(
(id) => !this.previousNodeIds.has(id),
);
const nodesRemoved = [...this.previousNodeIds].some(
(id) => !currentNodeIds.has(id),
);
if (nodesAdded || nodesRemoved) {
// Clean up filter - remove any nodes that no longer exist
if (this.previewNodeFilter.size > 0) {
const validFilterNodes = new Set(
[...this.previewNodeFilter].filter((id) => currentNodeIds.has(id)),
);
if (validFilterNodes.size !== this.previewNodeFilter.size) {
this.previewNodeFilter = validFilterNodes;
}
}
// Re-fetch previews if we have a selected model (topology changed)
if (this.selectedPreviewModelId) {
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
}
}
// Update tracked node IDs for next comparison
this.previousNodeIds = currentNodeIds;
}
/**
* Starts a chat conversation - triggers the topology minimization animation
* Creates a new conversation if none is active
@@ -1171,10 +1272,46 @@ class AppStore {
if (lastUserIndex === -1) return;
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
const lastUserMessage = this.messages[lastUserIndex];
const requestType = lastUserMessage.requestType || "chat";
const prompt = lastUserMessage.content;
// Resend the message to get a new response
// Remove messages after user message (including the user message for image requests
// since generateImage/editImage will re-add it)
this.messages = this.messages.slice(0, lastUserIndex);
switch (requestType) {
case "image-generation":
await this.generateImage(prompt);
break;
case "image-editing":
if (lastUserMessage.sourceImageDataUrl) {
await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
} else {
// Can't regenerate edit without source image - restore user message and show error
this.messages.push(lastUserMessage);
const errorMessage = this.addMessage("assistant", "");
const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
if (idx !== -1) {
this.messages[idx].content =
"Error: Cannot regenerate image edit - source image not found";
}
this.updateActiveConversation();
}
break;
case "chat":
default:
// Restore the user message for chat regeneration
this.messages.push(lastUserMessage);
await this.regenerateChatCompletion();
break;
}
}
/**
* Helper method to regenerate a chat completion response
*/
private async regenerateChatCompletion(): Promise<void> {
this.isLoading = true;
this.currentResponse = "";
@@ -1689,6 +1826,7 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-generation",
};
this.messages.push(userMessage);
@@ -1717,12 +1855,13 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
prompt,
n: params.numImages,
quality: params.quality,
size: params.size,
output_format: params.outputFormat,
response_format: "b64_json",
stream: true,
partial_images: 3,
stream: params.stream,
partial_images: params.partialImages,
};
if (hasAdvancedParams) {
@@ -1786,31 +1925,74 @@ class AppStore {
if (imageData && idx !== -1) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
const numImages = params.numImages;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.messages[idx].content =
`Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].content = progressText;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
this.messages[idx].attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
this.messages[idx].attachments = [
...finals,
partialAttachment,
];
}
} else if (parsed.type === "final") {
// Final image
this.messages[idx].content = "";
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First final image - replace any partial preview
this.messages[idx].attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
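// e.g. with numImages = 3 and imageIndex = 1, attachments go from
// [final0, partial1] to [final0, final1] once the second final arrives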
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
this.messages[idx].attachments = [
...previousFinals,
newAttachment,
];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
this.messages[idx].content =
`Generating image ${imageIndex + 2}/${numImages}...`;
} else {
this.messages[idx].content = "";
}
}
}
} catch {
@@ -1855,6 +2037,8 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-editing",
sourceImageDataUrl: imageDataUrl,
};
this.messages.push(userMessage);
@@ -1891,8 +2075,8 @@ class AppStore {
formData.append("size", params.size);
formData.append("output_format", params.outputFormat);
formData.append("response_format", "b64_json");
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
formData.append("partial_images", "3");
formData.append("stream", params.stream ? "1" : "0");
formData.append("partial_images", params.partialImages.toString());
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
@@ -2044,6 +2228,54 @@ class AppStore {
this.conversations.find((c) => c.id === this.activeConversationId) || null
);
}
/**
* Start a download on a specific node
*/
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
try {
const response = await fetch("/download/start", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
targetNodeId: nodeId,
shardMetadata: shardMetadata,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Failed to start download: ${response.status} - ${errorText}`,
);
}
} catch (error) {
console.error("Error starting download:", error);
throw error;
}
}
/**
* Delete a downloaded model from a specific node
*/
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
try {
const response = await fetch(
`/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,
{
method: "DELETE",
},
);
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Failed to delete download: ${response.status} - ${errorText}`,
);
}
} catch (error) {
console.error("Error deleting download:", error);
throw error;
}
}
}
export const appStore = new AppStore();
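For reference, a minimal sketch of driving these endpoints outside the dashboard. This is an illustration only: it assumes the HTTP API is reachable at http://localhost:52415 (substitute your node's actual host and port), and it URL-encodes path segments the same way the store does above.

import json
import urllib.request
from urllib.parse import quote

BASE = "http://localhost:52415"  # assumption: adjust to your deployment

def start_download(node_id: str, shard_metadata: dict) -> None:
    # Mirrors appStore.startDownload: POST /download/start with a JSON body.
    req = urllib.request.Request(
        f"{BASE}/download/start",
        data=json.dumps(
            {"targetNodeId": node_id, "shardMetadata": shard_metadata}
        ).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    urllib.request.urlopen(req)

def delete_download(node_id: str, model_id: str) -> None:
    # Mirrors appStore.deleteDownload: DELETE /download/{node_id}/{model_id}.
    req = urllib.request.Request(
        f"{BASE}/download/{quote(node_id, safe='')}/{quote(model_id, safe='')}",
        method="DELETE",
    )
    urllib.request.urlopen(req)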
@@ -2098,6 +2330,10 @@ export const setSelectedChatModel = (modelId: string) =>
appStore.setSelectedModel(modelId);
export const selectPreviewModel = (modelId: string | null) =>
appStore.selectPreviewModel(modelId);
export const togglePreviewNodeFilter = (nodeId: string) =>
appStore.togglePreviewNodeFilter(nodeId);
export const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();
export const previewNodeFilter = () => appStore.previewNodeFilter;
export const deleteMessage = (messageId: string) =>
appStore.deleteMessage(messageId);
export const editMessage = (messageId: string, newContent: string) =>
@@ -2134,6 +2370,10 @@ export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Thunderbolt bridge status
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
// Image generation params
export const imageGenerationParams = () => appStore.getImageGenerationParams();
export const setImageGenerationParams = (
@@ -2141,3 +2381,9 @@ export const setImageGenerationParams = (
) => appStore.setImageGenerationParams(params);
export const resetImageGenerationParams = () =>
appStore.resetImageGenerationParams();
// Download actions
export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);

View File

@@ -19,6 +19,9 @@
selectedPreviewModelId,
isLoadingPreviews,
selectPreviewModel,
togglePreviewNodeFilter,
clearPreviewNodeFilter,
previewNodeFilter,
createConversation,
setSelectedChatModel,
selectedChatModel,
@@ -28,6 +31,8 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
thunderboltBridgeCycles,
nodeThunderboltBridge,
type DownloadProgress,
type PlacementPreview,
} from "$lib/stores/app.svelte";
@@ -49,6 +54,41 @@
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
const nodeFilter = $derived(previewNodeFilter());
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8) + "...";
}
// Helper to get the thunderbolt bridge service name from a cycle
function getTbBridgeServiceName(cycle: string[]): string {
// Try to find a service name from any node in the cycle
for (const nodeId of cycle) {
const nodeData = tbBridgeData?.[nodeId];
if (nodeData?.serviceName) {
return nodeData.serviceName;
}
}
return "Thunderbolt Bridge"; // Fallback if no service name found
}
// Copy to clipboard state and function
let copiedCommand = $state(false);
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text);
copiedCommand = true;
setTimeout(() => {
copiedCommand = false;
}, 2000);
} catch (err) {
console.error("Failed to copy:", err);
}
}
let mounted = $state(false);
@@ -90,6 +130,15 @@
model.tasks.includes("ImageToImage")
);
}
// Helper to check if a model supports image editing
function modelSupportsImageEditing(modelId: string): boolean {
const model = models.find(
(m) => m.id === modelId || m.hugging_face_id === modelId,
);
if (!model?.tasks) return false;
return model.tasks.includes("ImageToImage");
}
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
@@ -181,6 +230,9 @@
// Preview card hover state for highlighting nodes in topology
let hoveredPreviewNodes = $state<Set<string>>(new Set());
// Computed: Check if filter is active (from store)
const isFilterActive = $derived(() => nodeFilter.size > 0);
// Helper to unwrap tagged instance for hover highlighting
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
if (!instanceWrapped || typeof instanceWrapped !== "object")
@@ -732,6 +784,8 @@
instanceWrapped: unknown,
): {
isDownloading: boolean;
isFailed: boolean;
errorMessage: string | null;
progress: DownloadProgress | null;
statusText: string;
perNode: Array<{
@@ -743,6 +797,8 @@
if (!downloadsData || Object.keys(downloadsData).length === 0) {
return {
isDownloading: false,
isFailed: false,
errorMessage: null,
progress: null,
statusText: "RUNNING",
perNode: [],
@@ -754,6 +810,8 @@
if (!instance || typeof instance !== "object") {
return {
isDownloading: false,
isFailed: false,
errorMessage: null,
progress: null,
statusText: "PREPARING",
perNode: [],
@@ -809,6 +867,26 @@
downloadKind
] as Record<string, unknown>;
// Handle DownloadFailed - return immediately with error info
if (downloadKind === "DownloadFailed") {
const downloadModelId = extractModelIdFromDownload(downloadPayload);
if (
instanceModelId &&
downloadModelId &&
downloadModelId === instanceModelId
) {
return {
isDownloading: false,
isFailed: true,
errorMessage:
(downloadPayload.errorMessage as string) || "Download failed",
progress: null,
statusText: "FAILED",
perNode: [],
};
}
}
if (downloadKind !== "DownloadOngoing") continue;
if (!downloadPayload) continue;
@@ -844,6 +922,8 @@
const statusInfo = deriveInstanceStatus(instanceWrapped);
return {
isDownloading: false,
isFailed: statusInfo.statusText === "FAILED",
errorMessage: null,
progress: null,
statusText: statusInfo.statusText,
perNode: [],
@@ -856,6 +936,8 @@
return {
isDownloading: true,
isFailed: false,
errorMessage: null,
progress: {
totalBytes,
downloadedBytes,
@@ -1451,7 +1533,7 @@
// Get ALL filtered previews based on current settings (matching minimum nodes)
// Note: previewsData already contains previews for the selected model (fetched via API)
// We filter by sharding/instance type and min nodes, returning ALL eligible previews
// The backend handles node_ids filtering; we filter by sharding/instance type and min nodes
const filteredPreviews = $derived(() => {
if (!selectedModelId || previewsData.length === 0) return [];
@@ -1584,7 +1666,86 @@
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div class="absolute top-4 left-4 group" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-yellow-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-sm font-mono text-yellow-200">
THUNDERBOLT BRIDGE CYCLE DETECTED
</span>
</div>
<!-- Tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-2">
A network routing cycle was detected between nodes connected
via Thunderbolt Bridge. This can cause connectivity issues.
</p>
<p class="text-xs text-white/60 mb-2">
<span class="text-yellow-300">Affected nodes:</span>
{cycle.map(getNodeName).join(" → ")}
</p>
<p class="text-xs text-white/60 mb-1">
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
Bridge on one of the affected nodes:
</p>
<button
type="button"
onclick={() => copyToClipboard(disableCmd)}
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
title="Click to copy"
>
<span class="flex-1">{disableCmd}</span>
<svg
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
{#if copiedCommand}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M5 13l4 4L19 7"
/>
{:else}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
/>
{/if}
</svg>
</button>
</div>
</div>
{/if}
<!-- Exit topology-only mode button -->
<button
type="button"
@@ -1624,7 +1785,111 @@
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
<div class="absolute top-4 left-4 group" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-yellow-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-sm font-mono text-yellow-200">
THUNDERBOLT BRIDGE CYCLE DETECTED
</span>
</div>
<!-- Tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-2">
A network routing cycle was detected between nodes connected
via Thunderbolt Bridge. This can cause connectivity issues.
</p>
<p class="text-xs text-white/60 mb-2">
<span class="text-yellow-300">Affected nodes:</span>
{cycle.map(getNodeName).join(" → ")}
</p>
<p class="text-xs text-white/60 mb-1">
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
Bridge on one of the affected nodes:
</p>
<button
type="button"
onclick={() => copyToClipboard(disableCmd)}
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
title="Click to copy"
>
<span class="flex-1">{disableCmd}</span>
<svg
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
{#if copiedCommand}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M5 13l4 4L19 7"
/>
{:else}
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
/>
{/if}
</svg>
</button>
</div>
</div>
{/if}
<!-- Node Filter Indicator (top-right corner) -->
{#if isFilterActive()}
<button
onclick={clearPreviewNodeFilter}
class="absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm"
title="Clear filter"
>
<span class="text-[10px] font-mono tracking-wider">
FILTER: {nodeFilter.size}
</span>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
<!-- Chat Input - Below topology -->
@@ -2061,6 +2326,13 @@
>
{downloadInfo.statusText}
</div>
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
<div
class="text-xs text-red-400/80 font-mono mt-1 break-words"
>
{downloadInfo.errorMessage}
</div>
{/if}
{/if}
</div>
</div>
@@ -2106,6 +2378,9 @@
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
@@ -2132,6 +2407,22 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2204,6 +2495,9 @@
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
@@ -2244,6 +2538,23 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span
@@ -2564,7 +2875,36 @@
<div
class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden"
>
<TopologyGraph highlightedNodes={highlightedNodes()} />
<TopologyGraph
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
{#if tbBridgeCycles.length > 0}
<div
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
title="Thunderbolt Bridge cycle detected - click to view details"
>
<svg
class="w-3.5 h-3.5 text-yellow-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
<span class="text-[10px] font-mono text-yellow-200"
>TB CYCLE</span
>
</div>
{/if}
</div>
</button>
@@ -2993,6 +3333,13 @@
>
{downloadInfo.statusText}
</div>
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
<div
class="text-xs text-red-400/80 font-mono mt-1 break-words"
>
{downloadInfo.errorMessage}
</div>
{/if}
{/if}
</div>
</div>

View File

@@ -6,6 +6,8 @@
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -28,6 +30,7 @@
etaMs: number;
status: "completed" | "downloading";
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -172,33 +175,6 @@
}
let downloadOverview = $state<NodeEntry[]>([]);
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
[],
);
async function fetchModels() {
try {
const response = await fetch("/models");
if (response.ok) {
const data = await response.json();
models = data.data || [];
}
} catch (error) {
console.error("Failed to fetch models:", error);
}
}
function getModelTotalBytes(
modelId: string,
downloadTotalBytes: number,
): number {
if (downloadTotalBytes > 0) return downloadTotalBytes;
const model = models.find((m) => m.id === modelId);
if (model?.storage_size_megabytes) {
return model.storage_size_megabytes * 1024 * 1024;
}
return 0;
}
$effect(() => {
try {
@@ -296,6 +272,12 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -312,6 +294,7 @@
? "completed"
: "downloading",
files,
shardMetadata,
};
const existing = modelMap.get(modelId);
@@ -373,7 +356,6 @@
onMount(() => {
// Ensure we fetch at least once when visiting downloads directly
refreshState();
fetchModels();
});
</script>
@@ -482,7 +464,7 @@
{#if model.status !== "completed"}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(
getModelTotalBytes(model.modelId, model.totalBytes),
model.totalBytes,
)}
</div>
{/if}
@@ -497,6 +479,52 @@
>
{pct.toFixed(1)}%
</span>
{#if model.status !== "completed" && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
{#if model.status === "completed"}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() =>
deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"

View File

@@ -0,0 +1,284 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
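        # Throttle to at most one progress event per second so the event
        # stream stays light during large downloads.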
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
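        # Tracking the task here lets _delete_download cancel an in-flight
        # download before removing the model's files from disk.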
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(model_id)
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Emit pending status to reset UI state, then remove from local tracking
if model_id in self.download_status:
current_status = self.download_status[model_id]
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)
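
The progress callback above forwards terminal `complete` events unconditionally but throttles `in_progress` updates to one per `throttle_interval_secs`. A minimal, self-contained sketch of the same pattern (the names below are illustrative, not the repo's API):

```python
import time

THROTTLE_SECS = 1.0  # mirrors throttle_interval_secs above

def make_throttled(emit):
    """Wrap `emit` so terminal events always go through but
    intermediate progress is rate-limited to one emit per window."""
    last = 0.0

    def on_progress(status: str, payload: dict) -> None:
        nonlocal last
        now = time.monotonic()
        if status == "complete":
            emit(status, payload)  # terminal events are never dropped
        elif status == "in_progress" and now - last > THROTTLE_SECS:
            emit(status, payload)
            last = now  # the window resets only when we actually emit

    return on_progress

cb = make_throttled(lambda s, p: print(s, p))
for pct in range(5):
    cb("in_progress", {"pct": pct})  # only the first call prints
cb("complete", {"pct": 100})  # always prints
```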

View File

@@ -24,6 +24,13 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
@@ -35,12 +42,27 @@ from exo.shared.types.worker.downloads import (
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
)
class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
return (
f"Model '{model_id}' requires authentication. "
f"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. "
f"Get a token at https://huggingface.co/settings/tokens"
)
elif status_code == 403:
return (
f"Access denied to '{model_id}'. "
f"Please accept the model terms at https://huggingface.co/{model_id}"
)
else:
return f"Authentication failed for '{model_id}' (HTTP {status_code})"
def trim_etag(etag: str) -> str:
@@ -147,6 +169,8 @@ async def fetch_file_list_with_retry(
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if attempt == n_attempts - 1:
raise e
@@ -167,6 +191,9 @@ async def _fetch_file_list(
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
@@ -256,6 +283,9 @@ async def file_meta(
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(model_id, revision, path, redirected_location)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -279,6 +309,8 @@ async def download_file_with_retry(
return await _download_file(
model_id, revision, path, target_dir, on_progress
)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
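
Both retry helpers re-raise `HuggingFaceAuthenticationError` ahead of the generic handler, so a 401/403 fails fast instead of burning the remaining attempts on an error that will never succeed. A minimal sketch of that ordering (simplified signature; the backoff is illustrative):

```python
import asyncio

class HuggingFaceAuthenticationError(Exception):
    """Raised when the hub returns 401/403 for a model download."""

async def fetch_with_retry(fetch, n_attempts: int = 3):
    for attempt in range(n_attempts):
        try:
            return await fetch()
        except HuggingFaceAuthenticationError:
            raise  # auth failures are permanent: retrying cannot fix them
        except Exception:
            if attempt == n_attempts - 1:
                raise  # out of attempts, surface the last error
            await asyncio.sleep(2**attempt)  # simple exponential backoff
```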
@@ -322,6 +354,9 @@ async def _download_file(
):
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)
@@ -463,7 +498,7 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.debug(f"Downloading {shard.model_card.model_id=}")
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
@@ -476,7 +511,7 @@ async def download_shard(
allow_patterns = await resolve_allow_patterns(shard)
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.debug(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(

View File

@@ -68,7 +68,11 @@ def get_hf_home() -> Path:
async def get_hf_token() -> str | None:
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
"""Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory."""
# Check environment variable first
if token := os.environ.get("HF_TOKEN"):
return token
# Fall back to file-based token
token_path = get_hf_home() / "token"
if await aios.path.exists(token_path):
async with aiofiles.open(token_path, "r") as f:

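The updated helper makes `HF_TOKEN` in the environment take precedence over the on-disk `$HF_HOME/token` file. A synchronous sketch of the same lookup order (the real helper is async and reads the file with aiofiles):

```python
import os
from pathlib import Path

def get_hf_token_sync(hf_home: Path) -> str | None:
    # The environment variable wins over the file-based token.
    if token := os.environ.get("HF_TOKEN"):
        return token
    token_path = hf_home / "token"
    if token_path.exists():
        return token_path.read_text().strip()
    return None
```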
View File

@@ -3,13 +3,15 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
@@ -19,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,7 +168,7 @@ class ResumableShardDownloader(ShardDownloader):
yield await task
# TODO: except Exception
except Exception as e:
print("Error downloading shard:", e)
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self
import anyio
from anyio.abc import TaskGroup
@@ -12,6 +13,8 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -21,7 +24,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -29,6 +31,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election  # Every node participates in the election: we want a node to become master, even one that isn't a master candidate, when no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -36,6 +39,7 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -49,8 +53,26 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -58,6 +80,7 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
@@ -67,11 +90,12 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -99,13 +123,25 @@ class Node:
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
event_index_counter,
)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -170,13 +206,27 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -185,6 +235,10 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -226,6 +280,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -268,6 +323,11 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

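Since `--no-downloads` is a plain `store_true` flag, downloads stay enabled unless it is passed explicitly; the Node then leaves `download_coordinator` as `None` and never starts it. A tiny standalone check of the flag semantics:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-downloads",
    action="store_true",
    help="Disable the download coordinator (node won't download models)",
)

# store_true defaults to False, so downloads are on by default.
assert parser.parse_args([]).no_downloads is False
assert parser.parse_args(["--no-downloads"]).no_downloads is True
```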
View File

@@ -1,14 +1,16 @@
import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Literal, cast
from typing import Annotated, Literal, cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -20,7 +22,10 @@ from loguru import logger
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_IMAGE_CACHE_DIR, EXO_MAX_CHUNK_SIZE
from exo.shared.constants import (
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
)
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
@@ -29,6 +34,7 @@ from exo.shared.models.model_cards import (
ModelId,
)
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
@@ -38,6 +44,7 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -55,19 +62,32 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
StartDownload,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -93,7 +113,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -102,7 +122,19 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
if isinstance(chunk, TokenChunk)
else ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
finish_reason=chunk.finish_reason,
)
],
@@ -131,12 +163,14 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -162,8 +196,12 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
@@ -231,6 +269,8 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -310,11 +350,20 @@ class API:
return placements[new_ids[0]]
async def get_placement_previews(
self, model_id: ModelId
self,
model_id: ModelId,
node_ids: Annotated[list[NodeId] | None, Query()] = None,
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
if len(list(self.state.topology.list_nodes())) == 0:
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -327,9 +376,7 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
for i in range(1, len(list(topology.list_nodes())) + 1)
]
)
# TODO: PDD
@@ -347,7 +394,7 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
topology=topology,
current_instances=self.state.instances,
)
except ValueError as exc:
@@ -439,11 +486,13 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
@@ -462,7 +511,8 @@ class API:
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
del self._chat_completion_queues[command_id]
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
@@ -470,6 +520,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
@@ -498,11 +549,12 @@ class API:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -511,7 +563,18 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -529,6 +592,7 @@ class API:
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
@@ -539,6 +603,7 @@ class API:
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
@@ -554,7 +619,19 @@ class API:
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
@@ -571,7 +648,7 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
role="assistant", content=combined_text, tool_calls=tool_calls
),
finish_reason=finish_reason,
)
@@ -729,7 +806,9 @@ class API:
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
with recv as chunks:
async for chunk in chunks:
@@ -769,6 +848,7 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),
@@ -838,7 +918,9 @@ class API:
stats: ImageGenerationStats | None = None
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
self._image_generation_queues[command_id], recv = channel[
ImageChunk | ErrorChunk
]()
while images_complete < num_images:
with recv as chunks:
@@ -956,6 +1038,9 @@ class API:
stream: bool,
partial_images: int,
bench: bool,
quality: Literal["high", "medium", "low"],
output_format: Literal["png", "jpeg", "webp"],
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
@@ -984,6 +1069,9 @@ class API:
stream=stream,
partial_images=partial_images,
bench=bench,
quality=quality,
output_format=output_format,
advanced_params=advanced_params,
),
)
@@ -994,7 +1082,6 @@ class API:
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
@@ -1019,12 +1106,22 @@ class API:
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1036,6 +1133,9 @@ class API:
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
if stream_bool and partial_images_int > 0:
@@ -1066,8 +1166,18 @@ class API:
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1079,6 +1189,9 @@ class API:
stream=False,
partial_images=0,
bench=True,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
return await self._collect_image_generation_with_stats(
@@ -1148,27 +1261,26 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(
event.command_id, None
)
elif event.command_id in self._image_generation_queues:
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
queue = self._image_generation_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._chat_completion_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
@@ -1191,3 +1303,28 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(
self, payload: StartDownloadParams
) -> StartDownloadResponse:
command = StartDownload(
target_node_id=payload.target_node_id,
shard_metadata=payload.shard_metadata,
)
await self._send_download(command)
return StartDownloadResponse(command_id=command.command_id)
async def delete_download(
self, node_id: NodeId, model_id: ModelId
) -> DeleteDownloadResponse:
command = DeleteDownload(
target_node_id=node_id,
model_id=ModelId(model_id),
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)

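A hedged sketch of calling the two new routes from a client, assuming the default `api_port` of 52415 and that `StartDownloadParams`, as a `CamelCaseModel`, expects camelCase JSON keys. The `{model_id:path}` converter in the delete route lets model ids contain `/`:

```python
import requests

BASE = "http://localhost:52415"  # default api_port

# Illustrative payload only: shardMetadata must be a serialized
# ShardMetadata; an empty dict will not pass validation.
payload = {"targetNodeId": "some-node-id", "shardMetadata": {}}
resp = requests.post(f"{BASE}/download/start", json=payload)
print(resp.json())  # expected shape: {"commandId": "..."}

# Model ids may contain "/" thanks to the :path converter.
requests.delete(f"{BASE}/download/some-node-id/black-forest-labs/FLUX.1-schnell")
```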
View File

@@ -87,12 +87,12 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle

View File

@@ -197,49 +197,6 @@ def get_shard_assignments(
)
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
cycles = cycle_digraph.get_cycles()
expected_length = len(list(cycle_digraph.list_nodes()))
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
if not cycles:
if expected_length > 1:
logger.warning(
f"No cycles of length {expected_length} found even though chosen subgraph contained {expected_length} nodes"
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
get_thunderbolt = True
logger.debug(f"Using thunderbolt cycle: {get_thunderbolt}")
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
cycle_digraph: Topology,
@@ -265,9 +222,6 @@ def get_mlx_jaccl_devices_matrix(
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
)
@@ -279,22 +233,11 @@ def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
) -> Generator[str, None, None]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
def _find_interface_name_for_ip(
ip_address: str, node_network: NodeNetworkInfo
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_network.interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
yield connection.sink_multiaddr.ip_address
def _find_ip_prioritised(
@@ -303,43 +246,25 @@ def _find_ip_prioritised(
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
Priority: ethernet > wifi > unknown > thunderbolt
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
for ip, _ in ips
if not ips:
return None
other_network = node_network.get(other_node_id, NodeNetworkInfo())
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
en0_ip = iface_map.get("en0")
if en0_ip:
return en0_ip
en1_ip = iface_map.get("en1")
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
def get_mlx_ring_hosts_by_node(
@@ -381,9 +306,6 @@ def get_mlx_ring_hosts_by_node(
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
@@ -416,9 +338,6 @@ def get_mlx_jaccl_coordinators(
if ip is not None:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)

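The rewritten `_find_ip_prioritised` reduces to a single `min` over candidate IPs keyed by interface-type rank. A self-contained sketch of that selection (the priority table is copied from the diff; the sample data is made up):

```python
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}

def pick_ip(ips: list[str], ip_to_type: dict[str, str]) -> str | None:
    if not ips:
        return None
    # Unknown IPs rank as "unknown" (2), matching the fallback in the diff.
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

ips = ["10.0.0.2", "169.254.7.1"]
types = {"10.0.0.2": "ethernet", "169.254.7.1": "thunderbolt"}
assert pick_ip(ips, types) == "10.0.0.2"  # ethernet beats thunderbolt
```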
View File

@@ -1,13 +1,9 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason
def test_image_chunk_with_error_fields() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="",
chunk_index=0,
total_chunks=1,
image_index=0,
finish_reason="error",
error_message="Image generation failed",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Image generation failed"
assert chunk.data == ""
assert chunk.chunk_index == 0
assert chunk.total_chunks == 1
assert chunk.image_index == 0
def test_image_chunk_without_error() -> None:
chunk = ImageChunk(
idx=0,
model=MODEL_A_ID,
data="base64encodeddata",
chunk_index=0,
total_chunks=1,
image_index=0,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
assert chunk.data == "base64encodeddata"

View File

@@ -3,7 +3,6 @@ import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
@@ -14,7 +13,7 @@ from exo.master.tests.conftest import (
)
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
@@ -273,45 +272,6 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
# act
hosts = get_hosts_from_subgraph(topology)
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()

View File

@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -30,6 +30,7 @@ from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -46,6 +47,7 @@ from exo.utils.info_gatherer.info_gatherer import (
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
ThunderboltBridgeInfo,
)
@@ -225,6 +227,21 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
node_thunderbolt_bridge = {
key: value
for key, value in state.node_thunderbolt_bridge.items()
if key != event.node_id
}
# Only recompute cycles if the leaving node had TB bridge enabled
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
leaving_node_had_tb_enabled = (
leaving_node_status is not None and leaving_node_status.enabled
)
thunderbolt_bridge_cycles = (
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
if leaving_node_had_tb_enabled
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
)
return state.model_copy(
update={
"downloads": downloads,
@@ -235,6 +252,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
"node_thunderbolt_bridge": node_thunderbolt_bridge,
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
}
)
@@ -312,6 +331,22 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
case ThunderboltBridgeInfo():
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
**state.node_thunderbolt_bridge,
event.node_id: info.status,
}
update["node_thunderbolt_bridge"] = new_tb_bridge
# Only recompute cycles if the enabled status changed
old_status = state.node_thunderbolt_bridge.get(event.node_id)
old_enabled = old_status.enabled if old_status else False
new_enabled = info.status.enabled
if old_enabled != new_enabled:
update["thunderbolt_bridge_cycles"] = (
topology.get_thunderbolt_bridge_cycles(
new_tb_bridge, state.node_network
)
)
return state.model_copy(update=update)

View File

@@ -49,3 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)

View File

@@ -9,6 +9,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
@@ -410,161 +411,166 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
# "flux1-schnell": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23782357120),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "flux1-dev": ModelCard(
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
# n_layers=57,
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(0),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="text_encoder_2",
# component_path="text_encoder_2/",
# storage_size=Memory.from_bytes(9524621312),
# n_layers=24,
# can_shard=False,
# safetensors_index_filename="model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(23802816640),
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
# "qwen-image-edit-2509": ModelCard(
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
# hidden_size=1,
# supports_tensor=False,
# tasks=[ModelTask.ImageToImage],
# components=[
# ComponentInfo(
# component_name="text_encoder",
# component_path="text_encoder/",
# storage_size=Memory.from_kb(16584333312),
# n_layers=12,
# can_shard=False,
# safetensors_index_filename=None, # Single file
# ),
# ComponentInfo(
# component_name="transformer",
# component_path="transformer/",
# storage_size=Memory.from_bytes(40860802176),
# n_layers=60,
# can_shard=True,
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
# ),
# ComponentInfo(
# component_name="vae",
# component_path="vae/",
# storage_size=Memory.from_kb(0),
# n_layers=None,
# can_shard=False,
# safetensors_index_filename=None,
# ),
# ],
# ),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
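
The image cards are now opt-in: they join `MODEL_CARDS` only when `EXO_ENABLE_IMAGE_MODELS` is the literal string "true" (case-insensitive) at import time, per the constant added above. A quick check of the gate's semantics:

```python
import os

os.environ["EXO_ENABLE_IMAGE_MODELS"] = "TRUE"
enabled = os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
assert enabled  # any casing of "true" enables; unset or anything else disables
```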
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
@@ -615,7 +621,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
@@ -627,7 +633,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
@@ -637,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
@@ -650,7 +656,7 @@ async def get_safetensors_size(model_id: ModelId) -> Memory:
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)

View File

@@ -7,6 +7,11 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
InterfaceType,
NodeNetworkInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.topology import (
Connection,
Cycle,
@@ -188,24 +193,25 @@ class Topology:
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
tb_edges = [
def get_rdma_cycles(self) -> list[Cycle]:
rdma_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
if isinstance(conn, RDMAConnection)
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
rx.PyDiGraph()
)
rdma_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
for u, v, conn in rdma_edges:
rdma_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycle_idxs = rx.simple_cycles(rdma_graph)
cycles: list[Cycle] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])
cycles.append(cycle)
return cycles
@@ -219,18 +225,83 @@ class Topology:
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
def is_rdma_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:
continue
has_tb = False
has_rdma = False
for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
if edge.is_thunderbolt():
has_tb = True
if isinstance(edge, RDMAConnection):
has_rdma = True
break
if not has_tb:
if not has_rdma:
return False
return True
def get_thunderbolt_bridge_cycles(
self,
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
for node_id, status in node_tb_bridge_status.items()
if status.enabled
}
if len(enabled_nodes) < 3:
return []
thunderbolt_ips = _get_ips_with_interface_type(
enabled_nodes, node_network, "thunderbolt"
)
# Build subgraph with only TB bridge enabled nodes and thunderbolt connections
graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()
node_to_idx: dict[NodeId, int] = {}
for node_id in enabled_nodes:
if node_id in self._vertex_indices:
node_to_idx[node_id] = graph.add_node(node_id)
for u, v, conn in self._graph.weighted_edge_list():
source_id, sink_id = self._graph[u], self._graph[v]
if source_id not in node_to_idx or sink_id not in node_to_idx:
continue
# Include connection if it's over a thunderbolt interface
if (
isinstance(conn, SocketConnection)
and conn.sink_multiaddr.ip_address in thunderbolt_ips
):
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
if isinstance(conn, RDMAConnection):
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) > 2
]
def _get_ips_with_interface_type(
node_ids: set[NodeId],
node_network: Mapping[NodeId, NodeNetworkInfo],
interface_type: InterfaceType,
) -> set[str]:
"""Get all IP addresses on interfaces of the specified type for the given nodes."""
ips: set[str] = set()
for node_id in node_ids:
network_info = node_network.get(node_id, NodeNetworkInfo())
for iface in network_info.interfaces:
if iface.interface_type == interface_type:
ips.add(iface.ip_address)
return ips

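`get_thunderbolt_bridge_cycles` only reports cycles longer than two nodes, since a simple A↔B pair cannot form a broadcast storm. A small rustworkx sketch of that filter (toy graph, not the repo's topology types):

```python
import rustworkx as rx

g: rx.PyDiGraph = rx.PyDiGraph()
a, b, c = g.add_node("a"), g.add_node("b"), g.add_node("c")
# Directed triangle a -> b -> c -> a, plus a 2-node cycle a <-> b.
g.add_edge(a, b, None)
g.add_edge(b, c, None)
g.add_edge(c, a, None)
g.add_edge(b, a, None)

big = [[g[i] for i in cyc] for cyc in rx.simple_cycles(g) if len(cyc) > 2]
print(big)  # [['a', 'b', 'c']] (rotation may differ); the 2-cycle is dropped
```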
View File

@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -54,6 +55,18 @@ class ChatCompletionMessageText(BaseModel):
text: str
class ToolCallItem(BaseModel):
name: str
arguments: str
class ToolCall(BaseModel):
id: str
index: int | None = None
type: Literal["function"] = "function"
function: ToolCallItem
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: (
@@ -61,7 +74,7 @@ class ChatCompletionMessage(BaseModel):
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None
@@ -340,3 +353,16 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from enum import Enum
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
class ChunkType(str, Enum):
Token = "Token"
Image = "Image"
from .worker.runner_response import ToolCallItem
class BaseChunk(TaggedModel):
idx: int
model: ModelId
class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
class ErrorChunk(BaseChunk):
error_message: str
finish_reason: Literal["error"] = "error"
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk

View File

@@ -1,6 +1,6 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -62,6 +62,19 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -79,3 +92,8 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Self
from typing import Literal, Self
import psutil
@@ -48,9 +48,13 @@ class SystemPerformanceProfile(CamelCaseModel):
ecpu_usage: float = 0.0
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
class NetworkInterfaceInfo(CamelCaseModel):
name: str
ip_address: str
interface_type: InterfaceType = "unknown"
class NodeIdentity(CamelCaseModel):
@@ -71,3 +75,11 @@ class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
class ThunderboltBridgeStatus(CamelCaseModel):
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
enabled: bool
exists: bool
service_name: str | None = None

View File

@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
ThunderboltBridgeStatus,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:

View File

@@ -21,9 +21,6 @@ class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
@@ -31,9 +28,6 @@ class SocketConnection(FrozenModel):
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId

View File

@@ -1,7 +1,12 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageGenerationStats,
ToolCallItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -25,6 +30,7 @@ class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -39,6 +45,7 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -48,5 +55,9 @@ class PartialImageResponse(BaseRunnerResponse):
yield name, value
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):
class RunnerRunning(BaseRunnerStatus):
"""Runner is processing requests and can accept more (continuous batching)."""
active_requests: int = 0
pass
class RunnerShuttingDown(BaseRunnerStatus):

View File

@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
ThunderboltBridgeStatus,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
IS_DARWIN = sys.platform == "darwin"
async def _get_thunderbolt_devices() -> set[str] | None:
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listallhardwareports"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listallhardwareports failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
output = result.stdout.decode()
thunderbolt_devices: set[str] = set()
current_port: str | None = None
for line in output.splitlines():
line = line.strip()
if line.startswith("Hardware Port:"):
current_port = line.split(":", 1)[1].strip()
elif line.startswith("Device:") and current_port:
device = line.split(":", 1)[1].strip()
if "thunderbolt" in current_port.lower():
thunderbolt_devices.add(device)
current_port = None
return thunderbolt_devices
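The parser above expects networksetup -listallhardwareports output roughly of this shape (an illustrative sample; actual ports and addresses vary per machine):

Hardware Port: Wi-Fi
Device: en0
Ethernet Address: aa:bb:cc:dd:ee:ff

Hardware Port: Thunderbolt 1
Device: en2
Ethernet Address: aa:bb:cc:dd:ee:00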
async def _get_bridge_services() -> dict[str, str] | None:
"""Get mapping of bridge device -> service name from network service order.
Returns None if the networksetup command fails.
"""
result = await anyio.run_process(
["networksetup", "-listnetworkserviceorder"],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -listnetworkserviceorder failed with code "
f"{result.returncode}: {result.stderr.decode()}"
)
return None
# Parse service order to find bridge devices and their service names
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
service_order_output = result.stdout.decode()
bridge_services: dict[str, str] = {} # device -> service name
current_service: str | None = None
for line in service_order_output.splitlines():
line = line.strip()
# Match "(N) Service Name" or "(*) Service Name" (disabled)
# but NOT "(Hardware Port: ...)" lines
if (
line
and line.startswith("(")
and ")" in line
and not line.startswith("(Hardware Port:")
):
paren_end = line.index(")")
if paren_end + 2 <= len(line):
current_service = line[paren_end + 2 :]
# Match "(Hardware Port: ..., Device: bridgeX)"
elif current_service and "Device: bridge" in line:
# Extract device name from "..., Device: bridge0)"
device_start = line.find("Device: ") + len("Device: ")
device_end = line.find(")", device_start)
if device_end > device_start:
device = line[device_start:device_end]
bridge_services[device] = current_service
return bridge_services
async def _get_bridge_members(bridge_device: str) -> set[str]:
"""Get member interfaces of a bridge device via ifconfig."""
result = await anyio.run_process(
["ifconfig", bridge_device],
check=False,
)
if result.returncode != 0:
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
return set()
members: set[str] = set()
ifconfig_output = result.stdout.decode()
for line in ifconfig_output.splitlines():
line = line.strip()
if line.startswith("member:"):
parts = line.split()
if len(parts) > 1:
members.add(parts[1])
return members
async def _find_thunderbolt_bridge(
bridge_services: dict[str, str], thunderbolt_devices: set[str]
) -> str | None:
"""Find the service name of a bridge containing Thunderbolt interfaces.
Returns the service name if found, None otherwise.
"""
for bridge_device, service_name in bridge_services.items():
members = await _get_bridge_members(bridge_device)
if members & thunderbolt_devices: # intersection is non-empty
return service_name
return None
async def _is_service_enabled(service_name: str) -> bool | None:
"""Check if a network service is enabled.
Returns True if enabled, False if disabled, None on error.
"""
result = await anyio.run_process(
["networksetup", "-getnetworkserviceenabled", service_name],
check=False,
)
if result.returncode != 0:
logger.warning(
f"networksetup -getnetworkserviceenabled '{service_name}' "
f"failed with code {result.returncode}: {result.stderr.decode()}"
)
return None
stdout = result.stdout.decode().strip().lower()
return stdout == "enabled"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
@@ -58,6 +195,66 @@ class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class ThunderboltBridgeInfo(TaggedModel):
status: ThunderboltBridgeStatus
@classmethod
async def gather(cls) -> Self | None:
"""Check if a Thunderbolt Bridge network service is enabled on this node.
Detection approach:
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
2. Find bridge devices from network service order (not hardware ports, as
bridges may not appear there)
3. Check each bridge's members via ifconfig
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
5. Check if that network service is enabled
"""
if not IS_DARWIN:
return None
def _no_bridge_status() -> Self:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=False, service_name=None
)
)
try:
tb_devices = await _get_thunderbolt_devices()
if tb_devices is None:
return _no_bridge_status()
bridge_services = await _get_bridge_services()
if not bridge_services:
return _no_bridge_status()
tb_service_name = await _find_thunderbolt_bridge(
bridge_services, tb_devices
)
if not tb_service_name:
return _no_bridge_status()
enabled = await _is_service_enabled(tb_service_name)
if enabled is None:
return cls(
status=ThunderboltBridgeStatus(
enabled=False, exists=True, service_name=tb_service_name
)
)
return cls(
status=ThunderboltBridgeStatus(
enabled=enabled,
exists=True,
service_name=tb_service_name,
)
)
except Exception as e:
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
return None
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@@ -111,6 +308,7 @@ GatheredInfo = (
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| ThunderboltBridgeInfo
| NodeConfig
| MiscData
| StaticNodeInformation
@@ -125,6 +323,7 @@ class InfoGatherer:
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
@@ -133,6 +332,7 @@ class InfoGatherer:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
@@ -200,12 +400,23 @@ class InfoGatherer:
return
old_nics = []
while True:
nics = get_network_interfaces()
nics = await get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return

View File

@@ -5,7 +5,7 @@ from subprocess import CalledProcessError
import psutil
from anyio import run_process
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
async def get_friendly_name() -> str:
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
"""
hostname = socket.gethostname()
# TODO: better non-Mac support
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
if sys.platform != "darwin":
return hostname
try:
@@ -28,7 +27,41 @@ async def get_friendly_name() -> str:
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
"""Parse networksetup -listallhardwareports to get interface types."""
if sys.platform != "darwin":
return {}
try:
result = await run_process(["networksetup", "-listallhardwareports"])
except CalledProcessError:
return {}
types: dict[str, InterfaceType] = {}
current_type: InterfaceType = "unknown"
for line in result.stdout.decode().splitlines():
if line.startswith("Hardware Port:"):
port_name = line.split(":", 1)[1].strip()
if "Wi-Fi" in port_name:
current_type = "wifi"
elif "Ethernet" in port_name or "LAN" in port_name:
current_type = "ethernet"
elif port_name.startswith("Thunderbolt"):
current_type = "thunderbolt"
else:
current_type = "unknown"
elif line.startswith("Device:"):
device = line.split(":", 1)[1].strip()
# enX devices are ethernet adapters or thunderbolt ports - these must be deprioritised
if device.startswith("en") and device not in ["en0", "en1"]:
current_type = "maybe_ethernet"
types[device] = current_type
return types
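Applied to output like the sample shown earlier, this parser would return roughly {"en0": "wifi", "en2": "maybe_ethernet"}: en2 is listed under a Thunderbolt hardware port, but the Device branch deliberately downgrades every enX other than en0/en1 to "maybe_ethernet" so that uncertain adapters can be deprioritised.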
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -36,13 +69,18 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
interface_types = await _get_interface_types_from_networksetup()
for iface, services in psutil.net_if_addrs().items():
for service in services:
match service.family:
case socket.AF_INET | socket.AF_INET6:
interfaces_info.append(
NetworkInterfaceInfo(name=iface, ip_address=service.address)
NetworkInterfaceInfo(
name=iface,
ip_address=service.address,
interface_type=interface_types.get(iface, "unknown"),
)
)
case _:
pass

View File

@@ -0,0 +1,32 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
class KeyedBackoff(Generic[K]):
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):
self._base = base
self._cap = cap
self._attempts: dict[K, int] = {}
self._last_time: dict[K, float] = {}
def should_proceed(self, key: K) -> bool:
"""Returns True if enough time has elapsed since last attempt."""
now = time.monotonic()
last = self._last_time.get(key, 0.0)
attempts = self._attempts.get(key, 0)
delay = min(self._cap, self._base * (2.0**attempts))
return now - last >= delay
def record_attempt(self, key: K) -> None:
"""Record that an attempt was made for this key."""
self._last_time[key] = time.monotonic()
self._attempts[key] = self._attempts.get(key, 0) + 1
def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
self._last_time.pop(key, None)
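A hypothetical usage sketch (try_download is an illustrative stand-in; the Worker further down applies the same pattern to gate StartDownload commands per model id):
backoff: KeyedBackoff[str] = KeyedBackoff(base=0.5, cap=10.0)

def maybe_download(model_id: str) -> None:
    if not backoff.should_proceed(model_id):
        return  # still within the backoff window for this key
    backoff.record_attempt(model_id)  # delay doubles: 0.5s, 1s, 2s, ... capped at 10s
    if try_download(model_id):  # hypothetical helper
        backoff.reset(model_id)  # next failure starts over at the base delay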

View File

@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,

View File

@@ -75,19 +75,20 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
seed = advanced_params.seed
base_seed = advanced_params.seed
else:
seed = random.randint(0, 2**32 - 1)
base_seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
num_images = task.n or 1
generation_start_time: float = 0.0
@@ -95,7 +96,11 @@ def generate_image(
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = task.partial_images or (3 if task.stream else 0)
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
)
image_path: Path | None = None
@@ -105,72 +110,81 @@ def generate_image(
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
)
else:
image = result
peak_memory_gb = mx.get_peak_memory() / (1024**3)
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images
seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)
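A quick sanity check of the aggregated stats above, with illustrative numbers: for num_images = 2 at 20 inference steps each and a 50s total generation time, total_steps = 20 * 2 = 40 and seconds_per_step = 50 / 40 = 1.25.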

View File

@@ -1,302 +0,0 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
import time
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""
command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""
command_id: CommandId
task_id: TaskId
response: GenerationResponse
class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._pending_completions: list[
int
] = [] # UIDs completed but not yet synced/removed
self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1
sampler = make_sampler(temp=0.7, top_p=1.0)
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)
logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.
In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
)
def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.
This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
Batches all pending inserts into a single batch_gen.insert() call for
efficient prefill batching.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)
if synced_data is None:
self._pending_inserts.clear()
return []
inserts_to_process = synced_data
if not inserts_to_process:
self._pending_inserts.clear()
return []
# Prepare all requests for batched insertion
all_tokens: list[list[int]] = []
all_max_tokens: list[int] = []
all_prompt_tokens: list[int] = []
request_info: list[tuple[CommandId, TaskId]] = []
for cmd_id, task_id, params in inserts_to_process:
prompt_str = apply_chat_template(self.tokenizer, params)
tokens: list[int] = self.tokenizer.encode(
prompt_str, add_special_tokens=False
)
max_tokens = params.max_tokens or self.max_tokens
all_tokens.append(tokens)
all_max_tokens.append(max_tokens)
all_prompt_tokens.append(len(tokens))
request_info.append((cmd_id, task_id))
# Single batched insert for efficient prefill
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
# Track all inserted requests
for i, uid in enumerate(uids):
cmd_id, task_id = request_info[i]
self.active_requests[uid] = ActiveRequest(
command_id=cmd_id,
task_id=task_id,
uid=uid,
detokenizer=self.tokenizer.detokenizer,
prompt_tokens=all_prompt_tokens[i],
)
logger.info(
f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
)
self._pending_inserts.clear()
return uids
def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
responses = self.batch_gen.next()
if not responses:
return []
results: list[BatchedGenerationResponse] = []
for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue
req.tokens_generated += 1
# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment
stats: GenerationStats | None = None
finish_reason: FinishReason | None = None
raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment
elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)
if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"
# Track completion but don't remove yet - wait for sync_completions()
self._pending_completions.append(uid)
logger.info(
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
)
results.append(
BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(
text=text, token=token, finish_reason=finish_reason, stats=stats
),
)
)
# In non-distributed mode, clean up completions immediately
if not self.is_distributed:
self._remove_completed()
return results
def sync_completions(self) -> None:
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
if not self.is_distributed:
# Non-distributed: early return if nothing to do
if not self._pending_completions:
return
self._remove_completed()
return
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
# This prevents deadlock if one rank has completions and another doesn't
assert self.group is not None
synced_uids = share_object(
self._pending_completions if self.rank == 0 else None,
self.rank,
self.group,
)
if synced_uids:
self._pending_completions = synced_uids
self._remove_completed()
def _remove_completed(self) -> None:
"""Remove completed requests from tracking."""
for uid in self._pending_completions:
if uid in self.active_requests:
del self.active_requests[uid]
self._pending_completions.clear()
@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self.active_requests)
@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def has_pending_completions(self) -> bool:
return bool(self._pending_completions)

View File

@@ -1,30 +0,0 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
# pyright: reportAny=false
import pickle
from typing import TypeVar, cast
import mlx.core as mx
T = TypeVar("T")
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))

View File

@@ -1,104 +0,0 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""
import time
from typing import Iterator
import mlx.core as mx
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
class TimeBudget(Iterator[None]):
"""Controls generation loop timing, syncing across ranks periodically.
In distributed mode, periodically syncs timing across all ranks to
dynamically adjust iteration count based on actual performance.
In non-distributed mode, simply runs for the time budget.
Usage:
for _ in TimeBudget(budget=0.5):
batch_engine.step()
# ... process responses ...
"""
def __init__(
self,
budget: float = 0.5,
iterations: int = 25,
sync_frequency: int = 10,
group: mx.distributed.Group | None = None,
):
"""Initialize TimeBudget.
Args:
budget: Time budget in seconds before yielding control
iterations: Initial number of iterations per budget period (distributed only)
sync_frequency: How often to sync timing across ranks (distributed only)
group: Distributed group, or None for non-distributed mode
"""
self._budget = budget
self._iterations = iterations
self._sync_frequency = sync_frequency
self._group = group
self._is_distributed = group is not None and group.size() > 1
# Runtime state
self._start: float = 0.0
self._current_iterations: int = 0
self._loops: int = 0
self._time_spent: float = 0.0
def __iter__(self) -> "TimeBudget":
self._start = time.perf_counter()
self._current_iterations = 0
return self
def __next__(self) -> None:
if not self._is_distributed:
# Non-distributed: just check time budget
if time.perf_counter() - self._start > self._budget:
raise StopIteration()
return None
# Distributed mode: iteration-based with periodic timing sync
self._current_iterations += 1
if self._current_iterations > self._iterations:
self._loops += 1
self._time_spent += time.perf_counter() - self._start
if self._loops % self._sync_frequency == 0:
# Sync timing across all ranks
assert self._group is not None
with mx.stream(generation_stream):
time_array = mx.array([self._time_spent], dtype=mx.float32)
total_time = mx.distributed.all_sum(time_array, group=self._group)
mx.eval(total_time)
loop_time = float(total_time.item())
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
if avg_loop_time > 0:
factor = self._budget / avg_loop_time
self._iterations = max(round(self._iterations * factor), 1)
logger.debug(
f"TimeBudget adjusted iterations to {self._iterations}"
)
self._loops = 0
self._time_spent = 0.0
raise StopIteration()
return None
@property
def iterations(self) -> int:
"""Current iterations per budget period."""
return self._iterations

View File

@@ -41,6 +41,7 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -55,7 +56,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
return tokenizer
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
"""
tool_calls = msg_dict.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
return
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
if not isinstance(tc, dict):
continue
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if not isinstance(func, dict):
continue
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if isinstance(args, str):
with contextlib.suppress(json.JSONDecodeError):
func["arguments"] = json.loads(args)
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
tools = chat_task_data.tools
formatted_messages: list[dict[str, Any]] = []
for message in messages:
@@ -386,15 +409,19 @@ def apply_chat_template(
continue
# Null values are not valid when applying templates in tokenizer
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
tools=tools,
)
logger.info(prompt)

View File

@@ -1,8 +1,9 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -10,7 +11,12 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
@@ -18,7 +24,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -36,22 +41,12 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -61,7 +56,6 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -69,23 +63,22 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.local_event_index = 0
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -100,6 +93,8 @@ class Worker:
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -110,7 +105,6 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -120,6 +114,7 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -178,11 +173,9 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -206,42 +199,26 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -386,78 +363,17 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
@@ -505,42 +421,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,7 +2,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
@@ -20,6 +19,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadProgress,
)
@@ -44,9 +44,6 @@ def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
@@ -58,7 +55,7 @@ def plan(
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
@@ -114,15 +111,22 @@ def _create_runner(
def _model_needs_download(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> DownloadModel | None:
local_downloads = global_download_status.get(node_id, [])
download_status = {
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
}
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
download_status[model_id],
(DownloadOngoing, DownloadCompleted, DownloadFailed),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -291,14 +295,12 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

File diff suppressed because it is too large

View File

@@ -105,7 +105,7 @@ class RunnerSupervisor:
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown successfully, terminating")
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
@@ -128,11 +128,9 @@ class RunnerSupervisor:
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(f"Skipping task {task.task_id} - already completed")
return
if task.task_id in self.pending:
logger.info(f"Skipping task {task.task_id} - already pending")
return
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -151,17 +149,13 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
# Just set the event to unblock start_task, but keep in pending
# to prevent duplicate forwarding until completion
if event.task_id in self.pending:
self.pending[event.task_id].set()
self.pending.pop(event.task_id).set()
continue
if isinstance(event, TaskStatusUpdated) and event.task_status in (
TaskStatus.Complete,
TaskStatus.TimedOut,
TaskStatus.Failed,
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just finished, we should be working on it.
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
@@ -172,8 +166,6 @@ class RunnerSupervisor:
RunnerShuttingDown,
),
)
# Now safe to remove from pending and add to completed
self.pending.pop(event.task_id, None)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:

View File

@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,13 +45,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -92,14 +88,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -116,7 +104,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -148,23 +135,19 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -202,12 +185,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -220,7 +197,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -245,7 +221,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,

View File

@@ -47,8 +47,7 @@ def test_plan_kills_runner_when_instance_missing():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -87,8 +86,7 @@ def test_plan_kills_runner_when_sibling_failed():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -120,7 +118,6 @@ def test_plan_creates_runner_when_missing_for_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -158,8 +155,7 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
result = plan_mod.plan(
node_id=NODE_A,
-runners=runners, # type: ignore
-download_status={},
+runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -189,7 +185,6 @@ def test_plan_does_not_create_runner_for_unassigned_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,

View File

@@ -65,7 +65,6 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -113,7 +112,6 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -158,7 +156,6 @@ def test_plan_does_not_forward_tasks_for_other_instances():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -221,7 +218,6 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -261,7 +257,6 @@ def test_plan_returns_none_when_nothing_to_do():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -57,7 +57,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -99,7 +98,6 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -140,7 +138,6 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -185,7 +182,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -202,7 +198,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -246,7 +241,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -289,7 +283,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -331,7 +324,6 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -1,330 +0,0 @@
"""
Tests for continuous batching behavior in the runner.
These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests that finish at different times complete independently
"""
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
from typing import Any
from unittest.mock import MagicMock
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.
This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"
if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
),
)
)
for uid in uids_to_remove:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def is_distributed(self) -> bool:
return False # Non-distributed mode for testing
class FakeGroup:
"""Fake MLX distributed group for testing."""
def size(self) -> int:
return 1 # Single node (non-distributed)
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
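# make_nothin(res) builds a stub that swallows any args/kwargs and returns res,
# e.g. monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup())).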
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.
Tasks are sent in order, with shutdown sent last.
Batch engine steps are interleaved with task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)
# No-op the close/join cleanup so the runner's shutdown doesn't close the test channels
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
return ChatCompletion(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
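# max_tokens bounds the fake engine above, so each request finishes after a
# predictable number of step() calls (default 3).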
def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.
Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, (
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
)
# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2
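
Condensed, this deleted file's fake models the engine contract the runner loop relies on: requests queue on insert, a sync step activates them, and each step() yields one token per active request until its max_tokens is reached. A standalone sketch of that driving loop against the FakeBatchEngineWithTokens above (the IDs are illustrative):

engine = FakeBatchEngineWithTokens()
engine.queue_request(CommandId("cmd1"), TaskId("t1"), make_chat_task("t1", "cmd1").task_params)
engine.queue_request(CommandId("cmd2"), TaskId("t2"), make_chat_task("t2", "cmd2", max_tokens=2).task_params)
engine.sync_and_insert_pending()  # both requests become active
while engine.has_active_requests:
    for r in engine.step():  # one BatchedGenerationResponse per active request
        print(r.task_id, r.response.text, r.response.finish_reason)
# t2 finishes after 2 steps (finish_reason="stop"), t1 after 3; their tokens interleave.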

View File

@@ -1,17 +1,12 @@
# Check tasks are complete before runner is ever ready.
# pyright: reportAny=false
from collections.abc import Iterable
-from typing import Any, Callable
-from unittest.mock import MagicMock
+from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
-from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -27,7 +22,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
-TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
@@ -44,9 +38,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
-from exo.worker.engines.mlx.generator.batch_engine import (
-BatchedGenerationResponse,
-)
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -116,100 +107,22 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
assert test_event == true_event, f"{test_event} != {true_event}"
class FakeBatchEngine:
"""
Fake batch engine for testing.
Queues requests on insert, returns one token per step.
The runner's non-blocking loop drains all tasks before running batch steps,
so this engine queues requests and has_active_requests returns True only
after at least one request has been inserted.
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._uid_counter = 0
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, _task_params in self._pending_inserts:
uid = self._uid_counter
self._uid_counter += 1
self._active_requests[uid] = (command_id, task_id)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
# Process all active requests - return one token and complete
for uid, (command_id, task_id) in list(self._active_requests.items()):
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=0,
text="hi",
finish_reason="stop",
),
)
)
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def is_distributed(self) -> bool:
return False # Non-distributed mode for testing
class FakeGroup:
"""Fake MLX distributed group for testing."""
def size(self) -> int:
return 1 # Single node (non-distributed)
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
-# initialize_mlx returns a fake "group" (non-None for state machine)
-monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
-monkeypatch.setattr(
-mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
-)
+# initialize_mlx returns a "group" equal to 1
+monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
+monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
@@ -227,6 +140,13 @@ class EventCollector:
pass
+class MockTokenizer:
+tool_parser = None
+tool_call_start = None
+tool_call_end = None
+has_tool_calling = False
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -252,14 +172,12 @@ def _run(tasks: Iterable[Task]):
return event_sender.events
-def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
-"""Verify chat completion generates tokens, completes, and runner returns to Ready."""
+def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
@@ -296,9 +214,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
-RunnerStatusUpdated(
-runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
-),
+RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -313,6 +229,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
+# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)
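
The expected-event lists above feed a pairwise ordering check. Reconstructed from the assert fragment shown near this file's hunk header, assert_events_equal plausibly looks like the sketch below (the zip_longest choice is an assumption so that missing or extra events also fail):

from itertools import zip_longest

def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
    # Compare event streams pairwise; zip_longest pads the shorter stream with
    # None, so a length mismatch surfaces as a failing assertion too.
    for test_event, true_event in zip_longest(test_events, true_events):
        assert test_event == true_event, f"{test_event} != {true_event}"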

View File

@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
+from exo.download.impl_shard_downloader import (
+build_full_shard,
+exo_shard_downloader,
+)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
-from exo.worker.download.impl_shard_downloader import (
-build_full_shard,
-exo_shard_downloader,
-)
from exo.worker.runner.bootstrap import entrypoint