Compare commits

...

10 Commits

Author SHA1 Message Date
Chris A
c65320acd3 Fix mlx seed (#1094)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-09 01:40:15 +00:00
Jake Hillion
b9a78f6f3a ci: compute CURRENT_PROJECT_VERSION from semver
Previous Sparkle builds were cut from a different repo with different
build numbers, breaking version ordering. Users aren't receiving updates
because CFBundleVersion values don't reflect the actual version sequence.

Added a step to compute the build version deterministically from semver:
PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR).
Release versions use prerelease=999 to ensure they're always higher than
their prereleases (e.g., 1.0.61 > 1.0.61-alpha.3).

This ensures consistent version ordering across repos, allowing Sparkle
to correctly identify and deliver updates to users.

Test plan:
- Verified formula with test script:

```sh
compute_version() {
  VERSION="$1"
  BASE_VERSION="${VERSION%%-*}"
  MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
  MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
  PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)

  if [[ "$VERSION" == *-* ]]; then
    PRERELEASE_PART="${VERSION#*-}"
    PRERELEASE_NUM="${PRERELEASE_PART##*.}"
    if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
      PRERELEASE_NUM=0
    fi
  else
    PRERELEASE_NUM=999
  fi

  BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
  printf "%-20s -> %12s\n" "$VERSION" "$BUILD_VERSION"
}

compute_version "1.0.61-alpha.2"
compute_version "1.0.61-alpha.3"
compute_version "1.0.61"
compute_version "1.0.62-alpha.1"
compute_version "1.1.0-alpha.1"
compute_version "2.0.0-alpha.1"
compute_version "0.0.0-alpha.0"
compute_version "0.0.1-alpha.1"
compute_version "1.2.3"
compute_version "1.2.3-beta.5"
```

- Output:

```sh
Version              -> Build Number
----------------------------------------
1.0.61-alpha.2       ->   1000061002
1.0.61-alpha.3       ->   1000061003
1.0.61               ->   1000061999
1.0.62-alpha.1       ->   1000062001
1.1.0-alpha.1        ->   1001000001
2.0.0-alpha.1        ->   2000000001
0.0.0-alpha.0        ->            0
0.0.1-alpha.1        ->         1001
1.2.3                ->   1002003999
1.2.3-beta.5         ->   1002003005
```

- Confirmed ordering: alpha.2 < alpha.3 < release < next-alpha
2026-01-08 19:52:33 +01:00
Jake Hillion
8f7f0e893a ci: avoid uploading alpha appcasts
Currently alpha appcasts get uploaded. It turns out these overwrite the
standard appcast, so even though no one will update to the alpha
channel, everyone will miss regular updates while the latest build was
an alpha one.

Ideally we should combine the source of truth for both the alpha and
release channels, but as no one is using the alpha channel for yet let's
stop uploading it for now.

Test plan:

![eyes](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExeGNwdDk0dmdscjlkZnd6eGxhcjJzdDBsYndmc2t2cnlpZDNxZnZhYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/gKHGnB1ml0moQdjhEJ/giphy.gif)
2026-01-08 18:52:10 +01:00
Alex Cheema
4759b09d4c Use presigned URLs for bug report uploads (#1109)
## Motivation

Previously we hardcoded AWS credentials into the app.
This is not good practice.

## Changes

Use presigned URLs instead.

## Why It Works

Presigned URLs are an S3 feature for this kind of thing. They provide an
expiring presigned URL with certain permissions. In this case we have a
presigned URL with `s3:PutObject` permission that expires after 5
minutes. The client uses this presigned URL to upload a bug report
instead of using its own credentials to sign a request. This also
simplifies a lot of the Swift code.

## Test Plan

### Manual Testing
On a single MacBook, I downloaded the app and sent a bug report. It
worked and appeared in the bucket.
2026-01-08 17:17:48 +00:00
Alex Cheema
ca680185f3 Display RDMA debug info in macOS app. (#1072)
## Motivation

Often users are running into issues with RDMA. See
https://github.com/exo-explore/exo/issues?q=is%3Aissue%20rdma
Having some debug info in the macOS app will help to debug these issues.

## Changes

Displays output of the following commands in the debug info section of
the macOS app:

1. `rdma_ctl status`
2. `ibv_devices`
3. `ibv_devinfo`

## Why It Works

It displays RDMA debug info in the debug info section of the macOS app.

## Test Plan

### Manual Testing
We need to make a new build of the macOS app and check the output under
the following conditions:

1. No RDMA enabled.
2. RDMA enabled but no devices connected over TB5.
3. RDMA enabled and devices connected over TB5.
2026-01-08 15:17:00 +00:00
Jake Hillion
383309e24e fmt: add typescript formatting
Add typescript auto formatting with Prettier and treefmt-nix. Added a
.prettierrc to useTabs, which isn't the default, to reduce churn. The
rest looks okay and will be checked by CI.

Test plan:
- CI
2026-01-08 13:47:27 +00:00
Jake Hillion
55463a9806 fmt: add swift formatting
Swift code currently has no auto formatting. Add `swift-format` to the
`treefmt-nix` config to get this formatted.

As our existing Swift code uses 4-space formatting instead of the
default 2-space, also adds a custom `.swift-format

Test plan:
- CI
2026-01-08 13:34:45 +00:00
Evan Quiney
56af61fac9 add a server for distributed testing in /tests until we work out a stable solution. (#1098)
## Motivation

Testing multiple devices simultaneously requires coordination, and we
don't necessarily want to run a full EXO to test single components. We
need a mid-scale integration testing framework for distributed tests.

## Changes

Add a simple python server + bash query that runs Jaccl and Ring tests
without constructing a worker/master/networking. The query relies on all
devices being accessible over tailscale, currently.

## Test Plan

Manually tested RDMA + Ring inference on 2 nodes.
2026-01-08 12:50:04 +00:00
Evan Quiney
f76d543d98 We shouldn't fail on an HTTPException in the tier-2 discovery system. (#1104)
## Motivation

Fixed a crash we found

## Changes

try/catch return None if we get an exception instead of crashing exo

## Test Plan

### Manual Testing
Exo launches. Couldn't repro the original case this arose.
2026-01-08 12:43:34 +00:00
Sami Khan
ea841aca37 local network check (#1103)
## Motivation

After machine restart, macOS local network permission can appear enabled
in System Settings but not actually work. EXO fails to discover other
machines, and the only fix is manually toggling the permission off/on
and relaunching. Users had no way to know this was happening.

## Changes

- Added LocalNetworkChecker service that detects if local network access
is actually functional
- Added warning banner with instructions and "Open Settings" button when
blocked
- Added NSLocalNetworkUsageDescription and NSBonjourServices to
Info.plist (required by macOS)

<img width="386" height="712" alt="image"
src="https://github.com/user-attachments/assets/c6fc873d-2c6a-4c9b-89cb-f7bc7322e25b"
/>

## Why It Works

Uses NWConnection to UDP multicast address 224.0.0.251:5353 (mDNS),
which is subject to the app's actual TCC permission state. Other
approaches (NWBrowser, dns-sd subprocess) either require additional
entitlements or run with their own permissions, giving false results.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
  - Toggle local network OFF in System Settings → warning banner appears
  - Toggle local network ON → warning disappears
  - Verified detection correctly reflects actual permission state

### Automated Testing
N/A
2026-01-08 12:24:46 +00:00
34 changed files with 1814 additions and 903 deletions

View File

@@ -18,6 +18,7 @@ jobs:
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
AWS_REGION: ${{ secrets.AWS_REGION }}
EXO_BUILD_NUMBER: ${{ github.run_number }}
EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -47,6 +48,32 @@ jobs:
fi
echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
- name: Compute build version from semver
run: |
VERSION="$RELEASE_VERSION"
# Extract major.minor.patch (strip prerelease suffix)
BASE_VERSION="${VERSION%%-*}"
MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
# Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
if [[ "$VERSION" == *-* ]]; then
PRERELEASE_PART="${VERSION#*-}"
PRERELEASE_NUM="${PRERELEASE_PART##*.}"
# Default to 0 if not a number
if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
PRERELEASE_NUM=0
fi
else
PRERELEASE_NUM=999
fi
# Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
echo "Computed build version: $BUILD_VERSION from $VERSION"
- name: Ensure tag commit is on main
if: github.ref_type == 'tag'
run: |
@@ -162,11 +189,12 @@ jobs:
-configuration Release \
-derivedDataPath build \
MARKETING_VERSION="$RELEASE_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
EXO_BUILD_TAG="$RELEASE_VERSION" \
EXO_BUILD_COMMIT="$GITHUB_SHA" \
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
mkdir -p ../../output
@@ -294,5 +322,5 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
if [[ "$IS_ALPHA" != "true" ]]; then
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

3
.prettierrc Normal file
View File

@@ -0,0 +1,3 @@
{
"useTabs": true
}

6
.swift-format Normal file
View File

@@ -0,0 +1,6 @@
{
"version": 1,
"indentation": {
"spaces": 4
}
}

View File

@@ -12,6 +12,7 @@ struct ContentView: View {
@EnvironmentObject private var controller: ExoProcessController
@EnvironmentObject private var stateService: ClusterStateService
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@@ -26,6 +27,9 @@ struct ContentView: View {
var body: some View {
VStack(alignment: .leading, spacing: 12) {
statusSection
if shouldShowLocalNetworkWarning {
localNetworkWarningBanner
}
if shouldShowClusterDetails {
Divider()
overviewSection
@@ -40,6 +44,7 @@ struct ContentView: View {
}
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
.padding()
.frame(width: 340)
.onAppear {
@@ -49,9 +54,62 @@ struct ContentView: View {
}
}
private var shouldShowLocalNetworkWarning: Bool {
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}
return false
}
private var localNetworkWarningBanner: some View {
VStack(alignment: .leading, spacing: 6) {
HStack(spacing: 6) {
Image(systemName: "exclamationmark.triangle.fill")
.foregroundColor(.orange)
Text("Local Network Access Issue")
.font(.caption)
.fontWeight(.semibold)
}
Text(
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openLocalNetworkSettings()
} label: {
Text("Open Settings")
.font(.caption2)
}
.buttonStyle(.bordered)
.controlSize(.small)
}
.padding(8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(Color.orange.opacity(0.1))
)
.overlay(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
)
}
private func openLocalNetworkSettings() {
// Open Privacy & Security settings - Local Network section
if let url = URL(
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
{
NSWorkspace.shared.open(url)
}
}
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
{
TopologyMiniView(topology: topology)
}
}
@@ -85,8 +143,10 @@ struct ContentView: View {
VStack(alignment: .leading, spacing: 4) {
HStack {
VStack(alignment: .leading) {
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
.font(.headline)
Text(
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
)
.font(.headline)
Text("Memory")
.font(.caption)
.foregroundColor(.secondary)
@@ -210,7 +270,9 @@ struct ContentView: View {
}
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
-> some View
{
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
}
@@ -241,9 +303,12 @@ struct ContentView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show All",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -394,6 +459,7 @@ struct ContentView: View {
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
@@ -403,6 +469,52 @@ struct ContentView: View {
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}
private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {
@@ -536,4 +648,3 @@ private struct HoverButton: View {
.onHover { isHovering = $0 }
}
}

View File

@@ -8,9 +8,9 @@
import AppKit
import CoreImage
import CoreImage.CIFilterBuiltins
import ServiceManagement
import Sparkle
import SwiftUI
import ServiceManagement
import UserNotifications
import os.log
@@ -19,6 +19,7 @@ struct EXOApp: App {
@StateObject private var controller: ExoProcessController
@StateObject private var stateService: ClusterStateService
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -37,9 +38,13 @@ struct EXOApp: App {
_stateService = StateObject(wrappedValue: service)
let networkStatus = NetworkStatusService()
_networkStatusService = StateObject(wrappedValue: networkStatus)
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -51,6 +56,7 @@ struct EXOApp: App {
.environmentObject(controller)
.environmentObject(stateService)
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
} label: {
menuBarIcon
@@ -107,7 +113,7 @@ struct EXOApp: App {
filter.contrast = 0.9
guard let output = filter.outputImage,
let rendered = ciContext.createCGImage(output, from: output.extent)
let rendered = ciContext.createCGImage(output, from: output.extent)
else {
return nil
}
@@ -120,7 +126,8 @@ struct EXOApp: App {
do {
try SMAppService.mainApp.register()
} catch {
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
Logger().error(
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}
@@ -145,7 +152,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
controller.updater.automaticallyChecksForUpdates = true
controller.updater.automaticallyDownloadsUpdates = false
controller.updater.updateCheckInterval = 900 // 15 minutes
controller.updater.updateCheckInterval = 900 // 15 minutes
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
controller?.updater.checkForUpdatesInBackground()
}
@@ -212,7 +219,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
func userNotificationCenter(
_ center: UNUserNotificationCenter,
willPresent notification: UNNotification,
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
Void
) {
completionHandler([.banner, .list, .sound])
}

View File

@@ -31,7 +31,8 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
}()
{
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
@@ -221,7 +222,9 @@ final class ExoProcessController: ObservableObject {
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
return tag
}
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
!short.isEmpty
{
return short
}
return "dev"

View File

@@ -8,5 +8,15 @@
<string>$(EXO_BUILD_TAG)</string>
<key>EXOBuildCommit</key>
<string>$(EXO_BUILD_COMMIT)</string>
<key>EXOBugReportPresignedUrlEndpoint</key>
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
<key>NSBonjourServices</key>
<array>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
let rawDownloads =
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
}
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
let payloads = try container.decode([String: ClusterInstancePayload].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
)
}
self.instance = ClusterInstance(
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
let payloads = try container.decode([String: RunnerStatusDetail].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
)
}
self.status = entry.key
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
func promptPreview() -> String? {
guard let messages else { return nil }
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
if let userMessage = messages.last(where: {
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
}) {
return userMessage.content
}
return messages.last?.content
@@ -365,5 +372,3 @@ extension ClusterState {
func availableModels() -> [ModelOption] { [] }
}

View File

@@ -1,4 +1,3 @@
import CryptoKit
import Foundation
struct BugReportOutcome: Equatable {
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
}
enum BugReportError: LocalizedError {
case missingCredentials
case invalidEndpoint
case presignedUrlFailed(String)
case uploadFailed(String)
case collectFailed(String)
var errorDescription: String? {
switch self {
case .missingCredentials:
return "Bug report upload credentials are not set."
case .invalidEndpoint:
return "Bug report endpoint is invalid."
case .presignedUrlFailed(let message):
return "Failed to get presigned URLs: \(message)"
case .uploadFailed(let message):
return "Bug report upload failed: \(message)"
case .collectFailed(let message):
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
}
struct BugReportService {
struct AWSConfig {
let accessKey: String
let secretKey: String
let region: String
let bucket: String
private struct PresignedUrlsRequest: Codable {
let keys: [String]
}
private struct PresignedUrlsResponse: Codable {
let urls: [String: String]
let expiresIn: Int?
}
func sendReport(
@@ -39,9 +40,9 @@ struct BugReportService {
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {
let credentials = try loadCredentials()
let timestamp = ISO8601DateFormatter().string(from: now)
let prefix = "reports/\(timestamp)/"
let timestamp = Self.runTimestampString(now)
let dayPrefix = Self.dayPrefixString(now)
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
let logData = readLog()
let ifconfigText = try await captureIfconfig()
@@ -66,28 +67,82 @@ struct BugReportService {
("\(prefix)exo.log", logData),
("\(prefix)state.json", stateData),
("\(prefix)events.json", eventsData),
("\(prefix)report.json", reportJSON)
("\(prefix)report.json", reportJSON),
]
let uploader = try S3Uploader(config: credentials)
for item in uploads {
guard let data = item.data else { continue }
try await uploader.upload(
objectPath: item.path,
body: data
)
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
guard let body = item.data else { return nil }
return (key: item.path, body: body)
}
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
guard !uploadItems.isEmpty else {
return BugReportOutcome(success: false, message: "No data to upload")
}
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
for item in uploadItems {
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
}
try await uploadToPresignedUrl(url: url, body: item.body)
}
return BugReportOutcome(
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
}
private func loadCredentials() throws -> AWSConfig {
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",
region: "us-east-1",
bucket: "exo-bug-reports"
)
private static func dayPrefixString(_ date: Date) -> String {
var calendar = Calendar(identifier: .gregorian)
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
let components = calendar.dateComponents([.year, .month, .day], from: date)
let year = components.year ?? 0
let month = components.month ?? 0
let day = components.day ?? 0
return String(format: "%04d/%02d/%02d", year, month, day)
}
private static func runTimestampString(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.locale = Locale(identifier: "en_US_POSIX")
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
return formatter.string(from: date)
}
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
-> [String: String]
{
guard
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
as? String
else {
throw BugReportError.invalidEndpoint
}
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
else {
throw BugReportError.invalidEndpoint
}
var request = URLRequest(url: endpoint)
request.httpMethod = "POST"
request.timeoutInterval = 10
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
let encoder = JSONEncoder()
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.presignedUrlFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
}
let decoder = JSONDecoder()
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
return decoded.urls
}
private func readLog() -> Data? {
@@ -100,7 +155,8 @@ struct BugReportService {
private func captureIfconfig() async throws -> String {
let result = runCommand(["/sbin/ifconfig"])
guard result.exitCode == 0 else {
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
throw BugReportError.collectFailed(
result.error.isEmpty ? "ifconfig failed" : result.error)
}
return result.output
}
@@ -108,12 +164,23 @@ struct BugReportService {
private func readDebugInfo() -> DebugInfo {
DebugInfo(
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
interfaces: readInterfaces()
interfaces: readInterfaces(),
rdma: readRDMADebugInfo()
)
}
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
DebugInfo.RDMADebugInfo(
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
)
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
@@ -156,7 +223,8 @@ struct BugReportService {
request.timeoutInterval = 5
do {
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
else {
return nil
}
return data
@@ -165,6 +233,36 @@ struct BugReportService {
}
}
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
let maxAttempts = 2
var lastError: Error?
for attempt in 1...maxAttempts {
do {
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.timeoutInterval = 30
let (_, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.uploadFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
}
return
} catch {
lastError = error
if attempt < maxAttempts {
try await Task.sleep(nanoseconds: 400_000_000)
}
}
}
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
}
private func makeReportJson(
timestamp: String,
hostName: String,
@@ -182,7 +280,7 @@ struct BugReportService {
"system": system,
"exo_version": exo.version as Any,
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated"
"report_type": isManual ? "manual" : "automated",
]
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
@@ -213,10 +311,13 @@ struct BugReportService {
let user = safeRunCommand(["/usr/bin/whoami"])
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
let uptime = safeRunCommand(["/usr/bin/uptime"])
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
let diskRoot = safeRunCommand([
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
])
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
let interfacesAndIPs = interfacesList?
let interfacesAndIPs =
interfacesList?
.split(whereSeparator: { $0 == " " || $0 == "\n" })
.compactMap { iface -> [String: Any]? in
let name = String(iface)
@@ -227,7 +328,8 @@ struct BugReportService {
} ?? []
let wifiSSID: String?
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
let airportPath =
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
if FileManager.default.isExecutableFile(atPath: airportPath) {
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
} else {
@@ -255,7 +357,7 @@ struct BugReportService {
"disk_root": diskRoot as Any,
"interfaces_and_ips": interfacesAndIPs,
"ipconfig_getiflist": interfacesList as Any,
"wifi_ssid": wifiSSID as Any
"wifi_ssid": wifiSSID as Any,
]
}
@@ -313,7 +415,8 @@ struct BugReportService {
for line in airportOutput.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("SSID:") {
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
in: .whitespaces)
}
}
return nil
@@ -350,6 +453,7 @@ struct BugReportService {
private struct DebugInfo {
let thunderboltBridgeDisabled: Bool?
let interfaces: [InterfaceStatus]
let rdma: RDMADebugInfo
struct InterfaceStatus {
let name: String
@@ -358,7 +462,21 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"name": name,
"ip": ip as Any
"ip": ip as Any,
]
}
}
struct RDMADebugInfo {
let rdmaCtlStatus: String?
let ibvDevices: String?
let ibvDevinfo: String?
func toDictionary() -> [String: Any] {
[
"rdma_ctl_status": rdmaCtlStatus as Any,
"ibv_devices": ibvDevices as Any,
"ibv_devinfo": ibvDevinfo as Any,
]
}
}
@@ -366,7 +484,8 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
"interfaces": interfaces.map { $0.toDictionary() }
"interfaces": interfaces.map { $0.toDictionary() },
"rdma": rdma.toDictionary(),
]
}
}
@@ -376,163 +495,3 @@ private struct CommandResult {
let output: String
let error: String
}
private struct S3Uploader {
let config: BugReportService.AWSConfig
init(config: BugReportService.AWSConfig) throws {
self.config = config
}
func upload(objectPath: String, body: Data) async throws {
let host = "\(config.bucket).s3.amazonaws.com"
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
throw BugReportError.invalidEndpoint
}
let now = Date()
let amzDate = awsTimestamp(now)
let dateStamp = dateStamp(now)
let payloadHash = sha256Hex(body)
let headers = [
"host": host,
"x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate
]
let canonicalRequest = buildCanonicalRequest(
method: "PUT",
url: url,
headers: headers,
payloadHash: payloadHash
)
let stringToSign = buildStringToSign(
amzDate: amzDate,
dateStamp: dateStamp,
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
)
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
let authorization = """
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
"""
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
request.setValue(host, forHTTPHeaderField: "Host")
request.setValue(authorization, forHTTPHeaderField: "Authorization")
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
_ = data // ignore response body for UX
throw BugReportError.uploadFailed("HTTP status \(statusText)")
}
}
private func buildCanonicalRequest(
method: String,
url: URL,
headers: [String: String],
payloadHash: String
) -> String {
let canonicalURI = encodePath(url.path)
let canonicalQuery = url.query ?? ""
let sortedHeaders = headers.sorted { $0.key < $1.key }
let canonicalHeaders = sortedHeaders
.map { "\($0.key.lowercased()):\($0.value)\n" }
.joined()
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
return [
method,
canonicalURI,
canonicalQuery,
canonicalHeaders,
signedHeaders,
payloadHash
].joined(separator: "\n")
}
private func encodePath(_ path: String) -> String {
return path
.split(separator: "/")
.map { segment in
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
}
.joined(separator: "/")
.prependSlashIfNeeded()
}
private func buildStringToSign(
amzDate: String,
dateStamp: String,
canonicalRequestHash: String
) -> String {
"""
AWS4-HMAC-SHA256
\(amzDate)
\(dateStamp)/\(config.region)/s3/aws4_request
\(canonicalRequestHash)
"""
}
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
let kRegion = hmac(key: kDate, data: Data(region.utf8))
let kService = hmac(key: kRegion, data: Data(service.utf8))
return hmac(key: kService, data: Data("aws4_request".utf8))
}
private func hmac(key: Data, data: Data) -> Data {
let keySym = SymmetricKey(data: key)
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
return Data(mac)
}
private func hmacHex(key: Data, data: Data) -> String {
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
}
private func sha256Hex(_ data: Data) -> String {
let digest = SHA256.hash(data: data)
return digest.compactMap { String(format: "%02x", $0) }.joined()
}
private func awsTimestamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private func dateStamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private static let rfc3986: CharacterSet = {
var set = CharacterSet.alphanumerics
set.insert(charactersIn: "-._~")
return set
}()
}
private extension String {
func prependSlashIfNeeded() -> String {
if hasPrefix("/") {
return self
}
return "/" + self
}
}

View File

@@ -57,7 +57,9 @@ final class ClusterStateService: ObservableObject {
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
@@ -113,7 +115,9 @@ final class ClusterStateService: ObservableObject {
}
}
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
async
{
do {
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
request.httpMethod = "POST"
@@ -122,7 +126,7 @@ final class ClusterStateService: ObservableObject {
"model_id": modelId,
"sharding": sharding,
"instance_meta": instanceMeta,
"min_nodes": minNodes
"min_nodes": minNodes,
]
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
let (_, response) = try await session.data(for: request)
@@ -143,7 +147,9 @@ final class ClusterStateService: ObservableObject {
do {
let url = baseURL.appendingPathComponent("models")
let (data, response) = try await session.data(from: url)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw URLError(.badServerResponse)
}
let list = try decoder.decode(ModelListResponse.self, from: data)

View File

@@ -0,0 +1,150 @@
import Foundation
import Network
import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
case unknown
case checking
case working
case notWorking(reason: String)
var isHealthy: Bool {
if case .working = self { return true }
return false
}
var displayText: String {
switch self {
case .unknown:
return "Unknown"
case .checking:
return "Checking..."
case .working:
return "Working"
case .notWorking(let reason):
return reason
}
}
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
self.status = result
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
// mDNS multicast address - same as libp2p uses for peer discovery
let host = NWEndpoint.Host("224.0.0.251")
let port = NWEndpoint.Port(integerLiteral: 5353)
let params = NWParameters.udp
params.allowLocalEndpointReuse = true
let conn = NWConnection(host: host, port: port, using: params)
connection = conn
return await withCheckedContinuation { continuation in
var hasResumed = false
let lock = NSLock()
let resumeOnce: (Status) -> Void = { status in
lock.lock()
defer { lock.unlock() }
guard !hasResumed else { return }
hasResumed = true
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
case .waiting(let error):
let errorStr = "\(error)"
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|| errorStr.contains("permission") || errorStr.contains("denied")
{
resumeOnce(.notWorking(reason: "Permission denied"))
} else {
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
}
case .cancelled, .setup, .preparing:
break
@unknown default:
break
}
}
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:
resumeOnce(.working)
case .waiting, .preparing, .setup:
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
default:
resumeOnce(.notWorking(reason: "Timeout"))
}
}
}
}
func stop() {
checkTask?.cancel()
checkTask = nil
connection?.cancel()
connection = nil
}
}

View File

@@ -5,61 +5,62 @@ import os.log
enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
#!/usr/bin/env bash
set -euo pipefail
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
Task.detached {
@@ -70,7 +71,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
@@ -82,7 +85,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
guard scriptExists, plistExists else { return false }
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
@@ -92,7 +96,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
@@ -105,58 +111,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript = script
let escapedScript =
script
.replacingOccurrences(of: "\\", with: "\\\\")
.replacingOccurrences(of: "\"", with: "\\\"")
let appleScriptSource = """
do shell script "\(escapedScript)" with administrator privileges
"""
do shell script "\(escapedScript)" with administrator privileges
"""
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
throw NetworkSetupError.scriptCreationFailed

View File

@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
let thunderboltBridgeState: ThunderboltState?
let bridgeInactive: Bool?
let interfaceStatuses: [InterfaceIpStatus]
let rdmaStatus: RDMAStatus
static let empty = NetworkStatus(
thunderboltBridgeState: nil,
bridgeInactive: nil,
interfaceStatuses: []
interfaceStatuses: [],
rdmaStatus: .empty
)
}
struct RDMAStatus: Equatable {
let rdmaCtlEnabled: Bool?
let devices: [String]
let activePorts: [RDMAPort]
var isAvailable: Bool {
rdmaCtlEnabled == true || !devices.isEmpty
}
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}
struct RDMAPort: Equatable {
let device: String
let port: String
let state: String
}
struct InterfaceIpStatus: Equatable {
let interfaceName: String
let ipAddress: String?
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
NetworkStatus(
thunderboltBridgeState: readThunderboltBridgeState(),
bridgeInactive: readBridgeInactive(),
interfaceStatuses: readInterfaceStatuses()
interfaceStatuses: readInterfaceStatuses(),
rdmaStatus: readRDMAStatus()
)
}
private func readRDMAStatus() -> RDMAStatus {
let rdmaCtlEnabled = readRDMACtlEnabled()
let devices = readRDMADevices()
let activePorts = readRDMAActivePorts()
return RDMAStatus(
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
}
private func readRDMACtlEnabled() -> Bool? {
let result = runCommand(["rdma_ctl", "status"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private func readRDMADevices() -> [String] {
let result = runCommand(["ibv_devices"])
guard result.exitCode == 0 else { return [] }
var devices: [String] = []
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|| trimmed.isEmpty
{
continue
}
let parts = trimmed.split(separator: " ", maxSplits: 1)
if let deviceName = parts.first {
devices.append(String(deviceName))
}
}
return devices
}
private func readRDMAActivePorts() -> [RDMAPort] {
let result = runCommand(["ibv_devinfo"])
guard result.exitCode == 0 else { return [] }
var ports: [RDMAPort] = []
var currentDevice: String?
var currentPort: String?
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("hca_id:") {
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("port:") {
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("state:") {
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
in: .whitespaces)
if let device = currentDevice, let port = currentPort {
if state.lowercased().contains("active") {
ports.append(RDMAPort(device: device, port: port, state: state))
}
}
}
}
return ports
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
private func readBridgeInactive() -> Bool? {
let result = runCommand(["ifconfig", "bridge0"])
guard result.exitCode == 0 else { return nil }
guard let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
guard
let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
else {
return nil
}
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
)
}
}

View File

@@ -107,10 +107,13 @@ extension ClusterState {
let nodeToRunner = instance.shardAssignments.nodeToRunner
let nodeIds = Array(nodeToRunner.keys)
let runnerIds = Array(nodeToRunner.values)
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
let nodeNames = nodeIds.compactMap {
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
}
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
let state = InstanceViewModel.State(
statuses: statuses, hasActiveDownload: downloadProgress != nil)
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
.sorted(by: { $0.sortPriority < $1.sortPriority })
.map { InstanceTaskViewModel(task: $0) }
@@ -165,8 +168,8 @@ extension ClusterState {
}
}
private extension InstanceViewModel.State {
init(statuses: [String], hasActiveDownload: Bool = false) {
extension InstanceViewModel.State {
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
if statuses.contains(where: { $0.contains("failed") }) {
self = .failed
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
self.parameters = task.parameters
}
}

View File

@@ -87,7 +87,9 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
guard !allNodes.isEmpty else { return nil }
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
@@ -106,18 +108,24 @@ extension ClusterState {
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
if let localId = localNodeId,
let index = orderedNodes.firstIndex(where: { $0.id == localId })
{
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }
return TopologyEdgeViewModel(
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edges = Set(edgesArray)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
return TopologyViewModel(
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}

View File

@@ -20,8 +20,8 @@ struct InstanceRowView: View {
if let progress = instance.downloadProgress {
downloadStatusView(progress: progress)
} else {
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
}
if let progress = instance.downloadProgress {
GeometryReader { geometry in
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
.font(.caption)
.fontWeight(.semibold)
if let subtitle = task.subtitle,
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
{
Text(subtitle)
.font(.caption2)
.foregroundColor(.secondary)
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
}
@ViewBuilder
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
private func detailRow(
icon: String? = nil, title: String, value: String, tint: Color = .secondary
) -> some View {
HStack(alignment: .firstTextBaseline, spacing: 6) {
if let icon {
Image(systemName: icon)
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
}
}
}

View File

@@ -32,4 +32,3 @@ struct NodeDetailView: View {
}
}
}

View File

@@ -28,4 +28,3 @@ struct NodeRowView: View {
.padding(.vertical, 4)
}
}

View File

@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
private func connectionLines(in size: CGSize) -> some View {
let positions = positionedNodes(in: size)
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
let positionById = Dictionary(
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
return Canvas { context, _ in
guard !topology.edges.isEmpty else { return }
let nodeRadius: CGFloat = 32
let arrowLength: CGFloat = 10
let arrowSpread: CGFloat = .pi / 7
for edge in topology.edges {
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
else { continue }
let dx = end.x - start.x
let dy = end.y - start.y
let distance = max(CGFloat(hypot(dx, dy)), 1)
let ux = dx / distance
let uy = dy / distance
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedStart = CGPoint(
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
var linePath = Path()
linePath.move(to: adjustedStart)
linePath.addLine(to: adjustedEnd)
context.stroke(
context.stroke(
linePath,
with: .color(.secondary.opacity(0.3)),
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
let angle = atan2(uy, ux)
let tip = adjustedEnd
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
.frame(width: 95)
}
}

View File

@@ -6,6 +6,7 @@
//
import Testing
@testable import EXO
struct EXOTests {

View File

@@ -11,4 +11,3 @@ declare global {
}
export {};

View File

@@ -1,8 +1,7 @@
export { default as TopologyGraph } from './TopologyGraph.svelte';
export { default as ChatForm } from './ChatForm.svelte';
export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';
export { default as TopologyGraph } from "./TopologyGraph.svelte";
export { default as ChatForm } from "./ChatForm.svelte";
export { default as ChatMessages } from "./ChatMessages.svelte";
export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";

View File

File diff suppressed because it is too large Load Diff

View File

@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
}
export interface ChatAttachment {
type: 'image' | 'text' | 'pdf' | 'audio';
type: "image" | "text" | "pdf" | "audio";
name: string;
content?: string;
base64Url?: string;
mimeType?: string;
}
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
export const IMAGE_EXTENSIONS = [
".jpg",
".jpeg",
".png",
".gif",
".webp",
".svg",
];
export const IMAGE_MIME_TYPES = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"image/svg+xml",
];
export const TEXT_EXTENSIONS = [
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
".txt",
".md",
".json",
".xml",
".yaml",
".yml",
".csv",
".log",
".js",
".ts",
".jsx",
".tsx",
".py",
".java",
".cpp",
".c",
".h",
".css",
".html",
".htm",
".sql",
".sh",
".bat",
".rs",
".go",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".dart",
".vue",
".svelte",
];
export const TEXT_MIME_TYPES = [
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
'application/json', 'application/xml', 'text/xml', 'application/javascript',
'text/javascript', 'application/typescript'
"text/plain",
"text/markdown",
"text/csv",
"text/html",
"text/css",
"application/json",
"application/xml",
"text/xml",
"application/javascript",
"text/javascript",
"application/typescript",
];
export const PDF_EXTENSIONS = ['.pdf'];
export const PDF_MIME_TYPES = ['application/pdf'];
export const PDF_EXTENSIONS = [".pdf"];
export const PDF_MIME_TYPES = ["application/pdf"];
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
export const AUDIO_MIME_TYPES = [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/mp4",
];
/**
* Get file category based on MIME type and extension
*/
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
return 'image';
export function getFileCategory(
mimeType: string,
fileName: string,
): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
if (
IMAGE_MIME_TYPES.includes(mimeType) ||
IMAGE_EXTENSIONS.includes(extension)
) {
return "image";
}
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
return 'pdf';
return "pdf";
}
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
return 'audio';
if (
AUDIO_MIME_TYPES.includes(mimeType) ||
AUDIO_EXTENSIONS.includes(extension)
) {
return "audio";
}
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
return 'text';
if (
TEXT_MIME_TYPES.includes(mimeType) ||
TEXT_EXTENSIONS.includes(extension) ||
mimeType.startsWith("text/")
) {
return "text";
}
return 'unknown';
return "unknown";
}
/**
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
*/
export function getAcceptString(categories: FileCategory[]): string {
const accepts: string[] = [];
for (const category of categories) {
switch (category) {
case 'image':
case "image":
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
break;
case 'text':
case "text":
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
break;
case 'pdf':
case "pdf":
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
break;
case 'audio':
case "audio":
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
break;
}
}
return accepts.join(',');
return accepts.join(",");
}
/**
* Format file size for display
*/
export function formatFileSize(bytes: number): string {
if (bytes === 0) return '0 B';
if (bytes === 0) return "0 B";
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const sizes = ["B", "KB", "MB", "GB"];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
}
/**
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
/**
* Process uploaded files into ChatUploadedFile format
*/
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
export async function processUploadedFiles(
files: File[],
): Promise<ChatUploadedFile[]> {
const results: ChatUploadedFile[] = [];
for (const file of files) {
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
const id =
Date.now().toString() + Math.random().toString(36).substring(2, 9);
const category = getFileCategory(file.type, file.name);
const base: ChatUploadedFile = {
id,
name: file.name,
size: file.size,
type: file.type,
file
file,
};
try {
if (category === 'image') {
if (category === "image") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else if (category === 'text' || category === 'unknown') {
} else if (category === "text" || category === "unknown") {
const textContent = await readFileAsText(file);
results.push({ ...base, textContent });
} else if (category === 'pdf') {
} else if (category === "pdf") {
results.push(base);
} else if (category === 'audio') {
} else if (category === "audio") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else {
results.push(base);
}
} catch (error) {
console.error('Error processing file:', file.name, error);
console.error("Error processing file:", file.name, error);
results.push(base);
}
}
return results;
}

View File

@@ -1,16 +1,15 @@
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite';
import tailwindcss from "@tailwindcss/vite";
import { sveltekit } from "@sveltejs/kit/vite";
import { defineConfig } from "vite";
export default defineConfig({
plugins: [tailwindcss(), sveltekit()],
server: {
proxy: {
'/v1': 'http://localhost:52415',
'/state': 'http://localhost:52415',
'/models': 'http://localhost:52415',
'/instance': 'http://localhost:52415'
}
}
"/v1": "http://localhost:52415",
"/state": "http://localhost:52415",
"/models": "http://localhost:52415",
"/instance": "http://localhost:52415",
},
},
});

View File

@@ -42,11 +42,22 @@
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs.ruff-format.enable = true;
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
programs.rustfmt.enable = true;
programs.rustfmt.package = (fenixToolchain system).rustfmt;
programs.nixpkgs-fmt.enable = true;
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = (fenixToolchain system).rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
};
};
in
{

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
@@ -47,7 +48,6 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -70,6 +70,9 @@ def warmup_inference(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
logger.info("Generating warmup tokens")
for _r in stream_generate(
model=model,
@@ -115,7 +118,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
task: ChatCompletionTaskParams,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
@@ -125,6 +127,9 @@ def mlx_generate(
# Currently we support chat-completion tasks only.
logger.info(f"task_params: {task}")
if task.seed is not None:
mx.random.seed(task.seed)
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
@@ -138,6 +143,11 @@ def mlx_generate(
eos_ids = eos_ids_from_tokenizer(tokenizer)
logits_processors = [ban_token_ids(eos_ids)]
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,

View File

@@ -3,11 +3,10 @@ import os
import resource
import time
from pathlib import Path
from typing import Any, Callable, cast
from typing import Any, cast
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -176,11 +175,7 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
logger.info("Created a sampler")
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
@@ -201,7 +196,7 @@ def load_mlx_items(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
return cast(Model, model), tokenizer, sampler
return cast(Model, model), tokenizer
def shard_and_load(

View File

@@ -6,7 +6,7 @@ from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
@@ -31,6 +31,8 @@ def entrypoint(
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
@@ -42,8 +44,10 @@ def entrypoint(
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")
try:
event_sender.close()
task_receiver.close()
finally:
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

View File

@@ -1,5 +1,7 @@
import time
import mlx.core as mx
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
@@ -36,12 +38,11 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_cleanup,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@@ -57,182 +58,153 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
try:
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
if timeout := getattr(shard_metadata, "should_timeout", 0):
time.sleep(timeout)
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
if timeout := getattr(shard_metadata, "should_timeout", 0):
time.sleep(timeout)
setup_start_time = time.time()
setup_start_time = time.time()
model = None
tokenizer = None
sampler = None
group = None
model = None
tokenizer = None
group = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
assert sampler
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert model
assert tokenizer
assert sampler
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
sampler=sampler,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
mlx_cleanup(model, tokenizer, group)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if isinstance(current_status, RunnerShutdown):
break
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(error_message=str(e)),
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

View File

@@ -32,6 +32,8 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()

246
tests/headless_runner.py Normal file
View File

@@ -0,0 +1,246 @@
import multiprocessing as mp
import socket
import time
import typing
import anyio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
)
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, mp_channel
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
logger.info(hn)
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
shutdown = anyio.Event()
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, ring_instance(test, iid))
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
global hn
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if hn.startswith(test.devs[i][0]):
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=(meta.n_layers // world_size) * i,
end_layer=min(
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
),
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
)
for i in range(world_size)
},
),
)
return instance
async def execute_test(test: Tests, instance: Instance):
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, jaccl_instance(test, iid))
def jaccl_instance(test: Tests, iid: InstanceId):
global hn
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
for name, _ in test.devs:
if hn.startswith(name):
hn = name
break
return MlxJacclInstance(
instance_id=iid,
ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
)
for i in range(world_size)
},
),
)
def new_runner(
instance: Instance,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

52
tests/start_distributed_test.sh Executable file
View File

@@ -0,0 +1,52 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"llama-3.2-1b\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
} &
done
wait