mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 13:29:29 -05:00
Compare commits
15 Commits
alexcheema
...
prioritise
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5de67883c0 | ||
|
|
6dbbe7797b | ||
|
|
9357503c6f | ||
|
|
ba19940828 | ||
|
|
f255345a1a | ||
|
|
a1939c89f2 | ||
|
|
cb9c9ee55c | ||
|
|
df240f834d | ||
|
|
cd125b3b8c | ||
|
|
b783a21399 | ||
|
|
43f12f5d08 | ||
|
|
8027d7933f | ||
|
|
ac6efa747b | ||
|
|
2e3c33db6d | ||
|
|
fc8e6ad06b |
@@ -276,23 +276,24 @@ class BatchGenerator:
|
||||
logprobs: mx.array
|
||||
finish_reason: Optional[str]
|
||||
|
||||
unprocessed_prompts: List[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model,
|
||||
max_tokens: int = ...,
|
||||
stop_tokens: Optional[set[int]] = ...,
|
||||
stop_tokens: Optional[set] = ...,
|
||||
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
|
||||
completion_batch_size: int = ...,
|
||||
prefill_batch_size: int = ...,
|
||||
prefill_step_size: int = ...,
|
||||
) -> None: ...
|
||||
def insert(
|
||||
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
|
||||
) -> List[int]: ...
|
||||
def stats(self) -> BatchStats: ...
|
||||
def next(self) -> List[Response]: ...
|
||||
self, prompts, max_tokens: Union[List[int], int, None] = ...
|
||||
): # -> list[Any]:
|
||||
...
|
||||
def stats(self): # -> BatchStats:
|
||||
...
|
||||
def next(self): # -> list[Any]:
|
||||
...
|
||||
|
||||
def batch_generate(
|
||||
model,
|
||||
|
||||
@@ -39,18 +39,12 @@ class StreamingDetokenizer:
|
||||
"""
|
||||
|
||||
__slots__ = ...
|
||||
tokens: list[int]
|
||||
def reset(self) -> None: ...
|
||||
def add_token(self, token: int) -> None: ...
|
||||
def finalize(self) -> None: ...
|
||||
def reset(self): ...
|
||||
def add_token(self, token): ...
|
||||
def finalize(self): ...
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""The full text decoded so far."""
|
||||
...
|
||||
@property
|
||||
def last_segment(self) -> str:
|
||||
def last_segment(self):
|
||||
"""Return the last segment of readable text since last time this property was accessed."""
|
||||
...
|
||||
|
||||
class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
|
||||
@@ -114,7 +108,6 @@ class TokenizerWrapper:
|
||||
_tokenizer: PreTrainedTokenizerFast
|
||||
eos_token_id: int | None
|
||||
eos_token: str | None
|
||||
eos_token_ids: list[int] | None
|
||||
bos_token_id: int | None
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
|
||||
39
AGENTS.md
39
AGENTS.md
@@ -116,45 +116,6 @@ From .cursorrules:
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## Model Storage
|
||||
|
||||
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
|
||||
|
||||
## Creating Model Instances via API
|
||||
|
||||
When testing with the API, you must first create a model instance before sending chat completions:
|
||||
|
||||
```bash
|
||||
# 1. Get instance previews for a model
|
||||
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
|
||||
|
||||
# 2. Create an instance from the first valid preview
|
||||
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
|
||||
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
|
||||
|
||||
# 3. Wait for the runner to become ready (check logs for "runner ready")
|
||||
|
||||
# 4. Send chat completions using the full model ID
|
||||
curl -X POST http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
|
||||
```
|
||||
|
||||
## Logs
|
||||
|
||||
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
|
||||
### Distributed Testing
|
||||
|
||||
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
|
||||
|
||||
```bash
|
||||
# On each machine in the test cluster, use the same unique namespace
|
||||
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
|
||||
```
|
||||
|
||||
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.
|
||||
|
||||
@@ -14,6 +14,7 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
@@ -24,6 +25,8 @@ struct ContentView: View {
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
@State private var pendingHFToken: String = ""
|
||||
@State private var pendingEnableImageModels = false
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -303,6 +306,49 @@ struct ContentView: View {
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
}
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("HuggingFace Token")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
SecureField("optional", text: $pendingHFToken)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingHFToken = controller.hfToken
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.hfToken = pendingHFToken
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingHFToken == controller.hfToken)
|
||||
}
|
||||
}
|
||||
Divider()
|
||||
HStack {
|
||||
Toggle(
|
||||
"Enable Image Models (experimental)", isOn: $pendingEnableImageModels
|
||||
)
|
||||
.toggleStyle(.switch)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingEnableImageModels = controller.enableImageModels
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Button("Save & Restart") {
|
||||
controller.enableImageModels = pendingEnableImageModels
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingEnableImageModels == controller.enableImageModels)
|
||||
}
|
||||
HoverButton(title: "Check for Updates", small: true) {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
@@ -423,6 +469,44 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
/// Shows TB bridge status for all nodes from exo cluster state
|
||||
private var clusterThunderboltBridgeView: some View {
|
||||
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
|
||||
let localNodeId = stateService.localNodeId
|
||||
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
|
||||
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
if bridgeStatuses.isEmpty {
|
||||
Text("Cluster TB Bridge: No data")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
Text("Cluster TB Bridge Status:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
|
||||
if let status = bridgeStatuses[nodeId] {
|
||||
let nodeName =
|
||||
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
|
||||
let isLocal = nodeId == localNodeId
|
||||
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
|
||||
let statusText =
|
||||
!status.exists
|
||||
? "N/A"
|
||||
: (status.enabled ? "Enabled" : "Disabled")
|
||||
let color: Color =
|
||||
!status.exists
|
||||
? .secondary
|
||||
: (status.enabled ? .red : .green)
|
||||
Text("\(prefix) \(statusText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(color)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var interfaceIpList: some View {
|
||||
let statuses = networkStatusService.status.interfaceStatuses
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
@@ -465,6 +549,7 @@ struct ContentView: View {
|
||||
Text(thunderboltStatusText)
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
clusterThunderboltBridgeView
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
|
||||
@@ -21,6 +21,7 @@ struct EXOApp: App {
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
|
||||
@@ -41,10 +42,13 @@ struct EXOApp: App {
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
|
||||
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
// Remove old LaunchDaemon components if they exist (from previous versions)
|
||||
cleanupLegacyNetworkSetup()
|
||||
// Check local network access periodically (warning disappears when user grants permission)
|
||||
localNetwork.startPeriodicChecking(interval: 10)
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -58,6 +62,7 @@ struct EXOApp: App {
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
.environmentObject(thunderboltBridgeService)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
}
|
||||
@@ -130,6 +135,37 @@ struct EXOApp: App {
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
private func cleanupLegacyNetworkSetup() {
|
||||
guard NetworkSetupHelper.hasInstalledComponents() else { return }
|
||||
// Dispatch async to ensure app is ready before showing alert
|
||||
DispatchQueue.main.async {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to configure local network discovery on your device. This requires granting permission once."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Continue")
|
||||
alert.addButton(withTitle: "Later")
|
||||
|
||||
let response = alert.runModal()
|
||||
guard response == .alertFirstButtonReturn else {
|
||||
Logger().info("User deferred legacy network setup cleanup")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try NetworkSetupHelper.uninstall()
|
||||
Logger().info("Cleaned up legacy network setup components")
|
||||
} catch {
|
||||
// Non-fatal: user may have cancelled admin prompt or cleanup may have
|
||||
// partially succeeded. The app will continue normally.
|
||||
Logger().warning(
|
||||
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
|
||||
@@ -3,6 +3,8 @@ import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
private let hfTokenKey = "EXOHFToken"
|
||||
private let enableImageModelsKey = "EXOEnableImageModels"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
@@ -37,6 +39,22 @@ final class ExoProcessController: ObservableObject {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
@Published var hfToken: String = {
|
||||
return UserDefaults.standard.string(forKey: hfTokenKey) ?? ""
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(hfToken, forKey: hfTokenKey)
|
||||
}
|
||||
}
|
||||
@Published var enableImageModels: Bool = {
|
||||
return UserDefaults.standard.bool(forKey: enableImageModelsKey)
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -191,6 +209,12 @@ final class ExoProcessController: ObservableObject {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
if !hfToken.isEmpty {
|
||||
environment["HF_TOKEN"] = hfToken
|
||||
}
|
||||
if enableImageModels {
|
||||
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
|
||||
}
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
|
||||
@@ -5,17 +5,43 @@ import Foundation
|
||||
struct ClusterState: Decodable {
|
||||
let instances: [String: ClusterInstance]
|
||||
let runners: [String: RunnerStatusSummary]
|
||||
let nodeProfiles: [String: NodeProfile]
|
||||
let tasks: [String: ClusterTask]
|
||||
let topology: Topology?
|
||||
let downloads: [String: [NodeDownloadStatus]]
|
||||
let thunderboltBridgeCycles: [[String]]
|
||||
|
||||
// Granular node state (split from the old nodeProfiles)
|
||||
let nodeIdentities: [String: NodeIdentity]
|
||||
let nodeMemory: [String: MemoryInfo]
|
||||
let nodeSystem: [String: SystemInfo]
|
||||
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
|
||||
|
||||
/// Computed property for backwards compatibility - merges granular state into NodeProfile
|
||||
var nodeProfiles: [String: NodeProfile] {
|
||||
var profiles: [String: NodeProfile] = [:]
|
||||
let allNodeIds = Set(nodeIdentities.keys)
|
||||
.union(nodeMemory.keys)
|
||||
.union(nodeSystem.keys)
|
||||
for nodeId in allNodeIds {
|
||||
let identity = nodeIdentities[nodeId]
|
||||
let memory = nodeMemory[nodeId]
|
||||
let system = nodeSystem[nodeId]
|
||||
profiles[nodeId] = NodeProfile(
|
||||
modelId: identity?.modelId,
|
||||
chipId: identity?.chipId,
|
||||
friendlyName: identity?.friendlyName,
|
||||
memory: memory,
|
||||
system: system
|
||||
)
|
||||
}
|
||||
return profiles
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
self.thunderboltBridgeCycles =
|
||||
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
|
||||
|
||||
// Granular node state
|
||||
self.nodeIdentities =
|
||||
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
|
||||
?? [:]
|
||||
self.nodeMemory =
|
||||
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
|
||||
self.nodeSystem =
|
||||
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
|
||||
self.nodeThunderboltBridge =
|
||||
try container.decodeIfPresent(
|
||||
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
|
||||
) ?? [:]
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case instances
|
||||
case runners
|
||||
case nodeProfiles
|
||||
case topology
|
||||
case tasks
|
||||
case downloads
|
||||
case thunderboltBridgeCycles
|
||||
case nodeIdentities
|
||||
case nodeMemory
|
||||
case nodeSystem
|
||||
case nodeThunderboltBridge
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
|
||||
let system: SystemInfo?
|
||||
}
|
||||
|
||||
struct NodeIdentity: Decodable {
|
||||
let modelId: String?
|
||||
let chipId: String?
|
||||
let friendlyName: String?
|
||||
}
|
||||
|
||||
struct ThunderboltBridgeStatus: Decodable {
|
||||
let enabled: Bool
|
||||
let exists: Bool
|
||||
let serviceName: String?
|
||||
}
|
||||
|
||||
struct MemoryInfo: Decodable {
|
||||
let ramTotal: MemoryValue?
|
||||
let ramAvailable: MemoryValue?
|
||||
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
|
||||
}
|
||||
|
||||
struct Topology: Decodable {
|
||||
let nodes: [TopologyNode]
|
||||
let connections: [TopologyConnection]?
|
||||
/// Node IDs in the topology
|
||||
let nodes: [String]
|
||||
/// Flattened list of connections (source -> sink pairs)
|
||||
let connections: [TopologyConnection]
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
|
||||
|
||||
// Connections come as nested map: { source: { sink: [edges] } }
|
||||
// We flatten to array of (source, sink) pairs
|
||||
var flatConnections: [TopologyConnection] = []
|
||||
if let nested = try container.decodeIfPresent(
|
||||
[String: [String: [AnyCodable]]].self, forKey: .connections
|
||||
) {
|
||||
for (source, sinks) in nested {
|
||||
for sink in sinks.keys {
|
||||
flatConnections.append(
|
||||
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
|
||||
}
|
||||
}
|
||||
}
|
||||
self.connections = flatConnections
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case nodes
|
||||
case connections
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyNode: Decodable {
|
||||
let nodeId: String
|
||||
let nodeProfile: NodeProfile
|
||||
/// Placeholder for decoding arbitrary JSON values we don't need to inspect
|
||||
private struct AnyCodable: Decodable {
|
||||
init(from decoder: Decoder) throws {
|
||||
// Just consume the value without storing it
|
||||
_ = try? decoder.singleValueContainer().decode(Bool.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Int.self)
|
||||
_ = try? decoder.singleValueContainer().decode(Double.self)
|
||||
_ = try? decoder.singleValueContainer().decode(String.self)
|
||||
_ = try? decoder.singleValueContainer().decode([AnyCodable].self)
|
||||
_ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyConnection: Decodable {
|
||||
struct TopologyConnection {
|
||||
let localNodeId: String
|
||||
let sendBackNodeId: String
|
||||
}
|
||||
|
||||
@@ -55,12 +55,16 @@ struct BugReportService {
|
||||
let stateData = try await stateResult
|
||||
let eventsData = try await eventsResult
|
||||
|
||||
// Extract cluster TB bridge status from exo state
|
||||
let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)
|
||||
|
||||
let reportJSON = makeReportJson(
|
||||
timestamp: timestamp,
|
||||
hostName: hostName,
|
||||
ifconfig: ifconfigText,
|
||||
debugInfo: debugInfo,
|
||||
isManual: isManual
|
||||
isManual: isManual,
|
||||
clusterTbBridgeStatus: clusterTbBridgeStatus
|
||||
)
|
||||
|
||||
let uploads: [(path: String, data: Data?)] = [
|
||||
@@ -178,18 +182,19 @@ struct BugReportService {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
let result = runCommand([
|
||||
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
|
||||
])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return false
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return nil
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return true
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
// Return true if disabled, false if enabled
|
||||
return !isEnabled
|
||||
}
|
||||
|
||||
private func readInterfaces() -> [DebugInfo.InterfaceStatus] {
|
||||
@@ -268,11 +273,12 @@ struct BugReportService {
|
||||
hostName: String,
|
||||
ifconfig: String,
|
||||
debugInfo: DebugInfo,
|
||||
isManual: Bool
|
||||
isManual: Bool,
|
||||
clusterTbBridgeStatus: [[String: Any]]?
|
||||
) -> Data? {
|
||||
let system = readSystemMetadata()
|
||||
let exo = readExoMetadata()
|
||||
let payload: [String: Any] = [
|
||||
var payload: [String: Any] = [
|
||||
"timestamp": timestamp,
|
||||
"host": hostName,
|
||||
"ifconfig": ifconfig,
|
||||
@@ -282,9 +288,38 @@ struct BugReportService {
|
||||
"exo_commit": exo.commit as Any,
|
||||
"report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
if let tbStatus = clusterTbBridgeStatus {
|
||||
payload["cluster_thunderbolt_bridge"] = tbStatus
|
||||
}
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
|
||||
/// Extracts cluster-wide Thunderbolt Bridge status from exo state JSON
|
||||
private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {
|
||||
guard let data = stateData,
|
||||
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
|
||||
let nodeThunderboltBridge = json["node_thunderbolt_bridge"] as? [String: [String: Any]]
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result: [[String: Any]] = []
|
||||
for (nodeId, status) in nodeThunderboltBridge {
|
||||
var entry: [String: Any] = ["node_id": nodeId]
|
||||
if let enabled = status["enabled"] as? Bool {
|
||||
entry["enabled"] = enabled
|
||||
}
|
||||
if let exists = status["exists"] as? Bool {
|
||||
entry["exists"] = exists
|
||||
}
|
||||
if let serviceName = status["service_name"] as? String {
|
||||
entry["service_name"] = serviceName
|
||||
}
|
||||
result.append(entry)
|
||||
}
|
||||
return result.isEmpty ? nil : result
|
||||
}
|
||||
|
||||
private func readSystemMetadata() -> [String: Any] {
|
||||
let hostname = safeRunCommand(["/bin/hostname"])
|
||||
let computerName = safeRunCommand(["/usr/sbin/scutil", "--get", "ComputerName"])
|
||||
|
||||
@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
private var periodicTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
/// Checks if local network access is working (one-time check).
|
||||
func check() {
|
||||
performCheck()
|
||||
}
|
||||
|
||||
/// Starts periodic checking of local network access.
|
||||
/// Re-checks every `interval` seconds so the warning disappears when user grants permission.
|
||||
func startPeriodicChecking(interval: TimeInterval = 10) {
|
||||
stopPeriodicChecking()
|
||||
// Do an immediate check first
|
||||
performCheck()
|
||||
// Then schedule periodic checks
|
||||
periodicTask = Task { [weak self] in
|
||||
while !Task.isCancelled {
|
||||
try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
|
||||
guard !Task.isCancelled else { break }
|
||||
self?.performCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops periodic checking.
|
||||
func stopPeriodicChecking() {
|
||||
periodicTask?.cancel()
|
||||
periodicTask = nil
|
||||
}
|
||||
|
||||
private func performCheck() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
// Only show "checking" status on first check to avoid UI flicker
|
||||
if status == .unknown {
|
||||
status = .checking
|
||||
}
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
// Only log on state changes or first check to reduce noise
|
||||
if isFirstCheck || result != self.status {
|
||||
Self.logger.info("Local network check: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
stopPeriodicChecking()
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
|
||||
@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge.sh"
|
||||
// Legacy script path from older versions
|
||||
private static let legacyScriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
}
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
return scriptExists || legacyScriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
# Remove the script (current and legacy paths) and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rm -f "$LEGACY_SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
@@ -98,99 +63,42 @@ enum NetworkSetupHelper {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
} || true
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
# We find it dynamically by looking for bridges containing Thunderbolt interfaces
|
||||
find_and_enable_thunderbolt_bridge() {
|
||||
# Get Thunderbolt interface devices from hardware ports
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
if echo "$members" | grep -qx "$tb_dev"; then
|
||||
# Find the service name for this bridge device
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
|
||||
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let interval = plist["StartInterval"] as? Int,
|
||||
interval == requiredStartInterval
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private static func installLaunchDaemon() async throws {
|
||||
let installerScript = makeInstallerScript()
|
||||
try runShellAsAdmin(installerScript)
|
||||
}
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript =
|
||||
script
|
||||
|
||||
@@ -153,22 +153,18 @@ private struct NetworkStatusFetcher {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
let lower = result.output.lowercased() + result.error.lowercased()
|
||||
if lower.contains("not a recognized network service") {
|
||||
return .deleted
|
||||
}
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return .deleted
|
||||
}
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return .enabled
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return .disabled
|
||||
}
|
||||
return nil
|
||||
|
||||
return isEnabled ? .enabled : .disabled
|
||||
}
|
||||
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
|
||||
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
@@ -0,0 +1,194 @@
|
||||
import Foundation
|
||||
import os.log
|
||||
|
||||
/// Utility for dynamically detecting Thunderbolt Bridge network services.
|
||||
/// This mirrors the Python logic in info_gatherer.py - we never assume the service
|
||||
/// is named "Thunderbolt Bridge", instead we find bridges containing Thunderbolt interfaces.
|
||||
enum ThunderboltBridgeDetector {
|
||||
private static let logger = Logger(
|
||||
subsystem: "io.exo.EXO", category: "ThunderboltBridgeDetector")
|
||||
|
||||
struct CommandResult {
|
||||
let exitCode: Int32
|
||||
let output: String
|
||||
let error: String
|
||||
}
|
||||
|
||||
/// Find the network service name of a bridge containing Thunderbolt interfaces.
|
||||
/// Returns nil if no such bridge exists.
|
||||
static func findThunderboltBridgeServiceName() -> String? {
|
||||
// 1. Get all Thunderbolt interface devices (e.g., en2, en3)
|
||||
guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {
|
||||
logger.debug("No Thunderbolt devices found")
|
||||
return nil
|
||||
}
|
||||
logger.debug("Found Thunderbolt devices: \(thunderboltDevices.joined(separator: ", "))")
|
||||
|
||||
// 2. Get bridge services from network service order
|
||||
guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {
|
||||
logger.debug("No bridge services found")
|
||||
return nil
|
||||
}
|
||||
logger.debug("Found bridge services: \(bridgeServices.keys.joined(separator: ", "))")
|
||||
|
||||
// 3. Find a bridge that contains Thunderbolt interfaces
|
||||
for (bridgeDevice, serviceName) in bridgeServices {
|
||||
let members = getBridgeMembers(bridgeDevice: bridgeDevice)
|
||||
logger.debug(
|
||||
"Bridge \(bridgeDevice) (\(serviceName)) has members: \(members.joined(separator: ", "))"
|
||||
)
|
||||
|
||||
// Check if any Thunderbolt device is a member of this bridge
|
||||
if !members.isDisjoint(with: thunderboltDevices) {
|
||||
logger.info(
|
||||
"Found Thunderbolt Bridge service: '\(serviceName)' (device: \(bridgeDevice))")
|
||||
return serviceName
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("No bridge found containing Thunderbolt interfaces")
|
||||
return nil
|
||||
}
|
||||
|
||||
/// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
|
||||
private static func getThunderboltDevices() -> Set<String>? {
|
||||
let result = runCommand(["networksetup", "-listallhardwareports"])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("networksetup -listallhardwareports failed: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
var thunderboltDevices: Set<String> = []
|
||||
var currentPort: String?
|
||||
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("Hardware Port:") {
|
||||
currentPort = String(trimmed.dropFirst("Hardware Port:".count)).trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("Device:"), let port = currentPort {
|
||||
let device = String(trimmed.dropFirst("Device:".count)).trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
if port.lowercased().contains("thunderbolt") {
|
||||
thunderboltDevices.insert(device)
|
||||
}
|
||||
currentPort = nil
|
||||
}
|
||||
}
|
||||
|
||||
return thunderboltDevices
|
||||
}
|
||||
|
||||
/// Get mapping of bridge device -> service name from network service order.
|
||||
private static func getBridgeServices() -> [String: String]? {
|
||||
let result = runCommand(["networksetup", "-listnetworkserviceorder"])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("networksetup -listnetworkserviceorder failed: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse service order to find bridge devices and their service names
|
||||
// Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
|
||||
var bridgeServices: [String: String] = [:]
|
||||
var currentService: String?
|
||||
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
|
||||
// Match "(N) Service Name" or "(*) Service Name" (disabled)
|
||||
// but NOT "(Hardware Port: ...)" lines
|
||||
if trimmed.hasPrefix("("), trimmed.contains(")"),
|
||||
!trimmed.hasPrefix("(Hardware Port:")
|
||||
{
|
||||
if let parenEnd = trimmed.firstIndex(of: ")") {
|
||||
let afterParen = trimmed.index(after: parenEnd)
|
||||
if afterParen < trimmed.endIndex {
|
||||
currentService =
|
||||
String(trimmed[afterParen...])
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Match "(Hardware Port: ..., Device: bridgeX)"
|
||||
else if let service = currentService, trimmed.contains("Device: bridge") {
|
||||
// Extract device name from "..., Device: bridge0)"
|
||||
if let deviceRange = trimmed.range(of: "Device: ") {
|
||||
let afterDevice = trimmed[deviceRange.upperBound...]
|
||||
if let parenIndex = afterDevice.firstIndex(of: ")") {
|
||||
let device = String(afterDevice[..<parenIndex])
|
||||
bridgeServices[device] = service
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bridgeServices
|
||||
}
|
||||
|
||||
/// Get member interfaces of a bridge device via ifconfig.
|
||||
private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {
|
||||
let result = runCommand(["ifconfig", bridgeDevice])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.debug("ifconfig \(bridgeDevice) failed")
|
||||
return []
|
||||
}
|
||||
|
||||
var members: Set<String> = []
|
||||
for line in result.output.components(separatedBy: .newlines) {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("member:") {
|
||||
let parts = trimmed.split(separator: " ")
|
||||
if parts.count > 1 {
|
||||
members.insert(String(parts[1]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return members
|
||||
}
|
||||
|
||||
/// Check if a network service is enabled.
|
||||
static func isServiceEnabled(serviceName: String) -> Bool? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", serviceName])
|
||||
guard result.exitCode == 0 else {
|
||||
logger.warning("Failed to check if '\(serviceName)' is enabled: \(result.error)")
|
||||
return nil
|
||||
}
|
||||
|
||||
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
if output.contains("enabled") {
|
||||
return true
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private static func runCommand(_ arguments: [String]) -> CommandResult {
|
||||
let process = Process()
|
||||
process.launchPath = "/usr/bin/env"
|
||||
process.arguments = arguments
|
||||
|
||||
let stdout = Pipe()
|
||||
let stderr = Pipe()
|
||||
process.standardOutput = stdout
|
||||
process.standardError = stderr
|
||||
|
||||
do {
|
||||
try process.run()
|
||||
} catch {
|
||||
return CommandResult(exitCode: -1, output: "", error: error.localizedDescription)
|
||||
}
|
||||
process.waitUntilExit()
|
||||
|
||||
let outputData = stdout.fileHandleForReading.readDataToEndOfFile()
|
||||
let errorData = stderr.fileHandleForReading.readDataToEndOfFile()
|
||||
|
||||
return CommandResult(
|
||||
exitCode: process.terminationStatus,
|
||||
output: String(decoding: outputData, as: UTF8.self),
|
||||
error: String(decoding: errorData, as: UTF8.self)
|
||||
)
|
||||
}
|
||||
}
|
||||
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
@@ -0,0 +1,258 @@
|
||||
import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
import Security
|
||||
import SystemConfiguration
|
||||
import os.log
|
||||
|
||||
@MainActor
|
||||
final class ThunderboltBridgeService: ObservableObject {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")
|
||||
|
||||
@Published private(set) var detectedCycle: [String]?
|
||||
@Published private(set) var hasPromptedForCurrentCycle = false
|
||||
@Published private(set) var lastError: String?
|
||||
|
||||
private weak var clusterStateService: ClusterStateService?
|
||||
private var cancellables = Set<AnyCancellable>()
|
||||
private var previousCycleSignature: String?
|
||||
|
||||
init(clusterStateService: ClusterStateService) {
|
||||
self.clusterStateService = clusterStateService
|
||||
setupObserver()
|
||||
}
|
||||
|
||||
private func setupObserver() {
|
||||
guard let service = clusterStateService else { return }
|
||||
|
||||
service.$latestSnapshot
|
||||
.compactMap { $0 }
|
||||
.sink { [weak self] snapshot in
|
||||
self?.checkForCycles(snapshot: snapshot)
|
||||
}
|
||||
.store(in: &cancellables)
|
||||
}
|
||||
|
||||
private func checkForCycles(snapshot: ClusterState) {
|
||||
let cycles = snapshot.thunderboltBridgeCycles
|
||||
|
||||
// Only consider cycles with more than 2 nodes
|
||||
guard let firstCycle = cycles.first, firstCycle.count > 2 else {
|
||||
// No problematic cycles detected, reset state
|
||||
if detectedCycle != nil {
|
||||
detectedCycle = nil
|
||||
previousCycleSignature = nil
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Create a signature for this cycle to detect if it changed
|
||||
let cycleSignature = firstCycle.sorted().joined(separator: ",")
|
||||
|
||||
// If this is a new/different cycle, reset the prompt state
|
||||
if cycleSignature != previousCycleSignature {
|
||||
previousCycleSignature = cycleSignature
|
||||
hasPromptedForCurrentCycle = false
|
||||
}
|
||||
|
||||
detectedCycle = firstCycle
|
||||
|
||||
// Only prompt once per cycle
|
||||
if !hasPromptedForCurrentCycle {
|
||||
showDisableBridgePrompt(nodeIds: firstCycle)
|
||||
}
|
||||
}
|
||||
|
||||
private func showDisableBridgePrompt(nodeIds: [String]) {
|
||||
hasPromptedForCurrentCycle = true
|
||||
|
||||
// Get friendly names for the nodes if available
|
||||
let nodeNames = nodeIds.map { nodeId -> String in
|
||||
if let snapshot = clusterStateService?.latestSnapshot,
|
||||
let profile = snapshot.nodeProfiles[nodeId],
|
||||
let friendlyName = profile.friendlyName, !friendlyName.isEmpty
|
||||
{
|
||||
return friendlyName
|
||||
}
|
||||
return String(nodeId.prefix(8)) // Use first 8 chars of node ID as fallback
|
||||
}
|
||||
let machineNames = nodeNames.joined(separator: ", ")
|
||||
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Thunderbolt Bridge Loop Detected"
|
||||
alert.informativeText = """
|
||||
A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).
|
||||
|
||||
This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Disable Bridge")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
|
||||
let response = alert.runModal()
|
||||
|
||||
if response == .alertFirstButtonReturn {
|
||||
Task {
|
||||
await disableThunderboltBridge()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func disableThunderboltBridge() async {
|
||||
Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
|
||||
lastError = nil
|
||||
|
||||
do {
|
||||
try await disableThunderboltBridgeWithSCPreferences()
|
||||
Self.logger.info("Successfully disabled Thunderbolt Bridge")
|
||||
} catch {
|
||||
Self.logger.error(
|
||||
"Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
lastError = error.localizedDescription
|
||||
showErrorAlert(message: error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
private func disableThunderboltBridgeWithSCPreferences() async throws {
|
||||
// 1. Create authorization reference
|
||||
var authRef: AuthorizationRef?
|
||||
var status = AuthorizationCreate(nil, nil, [], &authRef)
|
||||
guard status == errAuthorizationSuccess, let authRef = authRef else {
|
||||
throw ThunderboltBridgeError.authorizationFailed
|
||||
}
|
||||
|
||||
defer { AuthorizationFree(authRef, [.destroyRights]) }
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
}
|
||||
throw ThunderboltBridgeError.authorizationDenied
|
||||
}
|
||||
|
||||
// 3. Create SCPreferences with authorization
|
||||
guard
|
||||
let prefs = SCPreferencesCreateWithAuthorization(
|
||||
kCFAllocatorDefault,
|
||||
"EXO" as CFString,
|
||||
nil,
|
||||
authRef
|
||||
)
|
||||
else {
|
||||
throw ThunderboltBridgeError.preferencesCreationFailed
|
||||
}
|
||||
|
||||
// 4. Lock, modify, commit
|
||||
guard SCPreferencesLock(prefs, true) else {
|
||||
throw ThunderboltBridgeError.lockFailed
|
||||
}
|
||||
|
||||
defer {
|
||||
SCPreferencesUnlock(prefs)
|
||||
}
|
||||
|
||||
// 5. Find the Thunderbolt Bridge service dynamically (don't assume the name)
|
||||
guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()
|
||||
else {
|
||||
throw ThunderboltBridgeError.serviceNotFound
|
||||
}
|
||||
|
||||
guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
|
||||
throw ThunderboltBridgeError.servicesNotFound
|
||||
}
|
||||
|
||||
var found = false
|
||||
for service in allServices {
|
||||
if let name = SCNetworkServiceGetName(service) as String?,
|
||||
name == targetServiceName
|
||||
{
|
||||
guard SCNetworkServiceSetEnabled(service, false) else {
|
||||
throw ThunderboltBridgeError.disableFailed
|
||||
}
|
||||
found = true
|
||||
Self.logger.info(
|
||||
"Found and disabled Thunderbolt Bridge service: '\(targetServiceName)'")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
throw ThunderboltBridgeError.serviceNotFound
|
||||
}
|
||||
|
||||
// 6. Commit and apply
|
||||
guard SCPreferencesCommitChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.commitFailed
|
||||
}
|
||||
|
||||
guard SCPreferencesApplyChanges(prefs) else {
|
||||
throw ThunderboltBridgeError.applyFailed
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Failed to Disable Thunderbolt Bridge"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
}
|
||||
|
||||
enum ThunderboltBridgeError: LocalizedError {
|
||||
case authorizationFailed
|
||||
case authorizationCanceled
|
||||
case authorizationDenied
|
||||
case preferencesCreationFailed
|
||||
case lockFailed
|
||||
case servicesNotFound
|
||||
case serviceNotFound
|
||||
case disableFailed
|
||||
case commitFailed
|
||||
case applyFailed
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .authorizationFailed:
|
||||
return "Failed to create authorization"
|
||||
case .authorizationCanceled:
|
||||
return "Authorization was canceled by user"
|
||||
case .authorizationDenied:
|
||||
return "Authorization was denied"
|
||||
case .preferencesCreationFailed:
|
||||
return "Failed to access network preferences"
|
||||
case .lockFailed:
|
||||
return "Failed to lock network preferences for modification"
|
||||
case .servicesNotFound:
|
||||
return "Could not retrieve network services"
|
||||
case .serviceNotFound:
|
||||
return "Thunderbolt Bridge service not found"
|
||||
case .disableFailed:
|
||||
return "Failed to disable Thunderbolt Bridge service"
|
||||
case .commitFailed:
|
||||
return "Failed to save network configuration changes"
|
||||
case .applyFailed:
|
||||
return "Failed to apply network configuration changes"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ struct TopologyViewModel {
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let topologyNodeIds = Set(topology?.nodes ?? [])
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
@@ -95,8 +95,8 @@ extension ClusterState {
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
var orderedNodes: [NodeViewModel] = []
|
||||
if let topologyNodes = topology?.nodes {
|
||||
for topoNode in topologyNodes {
|
||||
if let viewModel = nodesById[topoNode.nodeId] {
|
||||
for nodeId in topologyNodes {
|
||||
if let viewModel = nodesById[nodeId] {
|
||||
orderedNodes.append(viewModel)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ extension ClusterState {
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
topology?.connections.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,6 +865,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -904,6 +905,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,6 +1522,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1529,6 +1532,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,6 +1945,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2648,6 +2653,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2690,6 +2696,7 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2862,6 +2869,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3006,6 +3014,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3027,6 +3036,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -3,6 +3,45 @@
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Stub source with lockfiles and minimal files for build to succeed
|
||||
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
|
||||
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
|
||||
mkdir -p $out
|
||||
cp ${inputs.self}/dashboard/package.json $out/
|
||||
cp ${inputs.self}/dashboard/package-lock.json $out/
|
||||
# Minimal files so vite build succeeds (produces empty output)
|
||||
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
|
||||
mkdir -p $out/src
|
||||
touch $out/src/app.html
|
||||
'';
|
||||
|
||||
# Deps-only build using stub source (for prettier-svelte)
|
||||
# Only rebuilds when package.json or package-lock.json change
|
||||
dashboardDeps = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
|
||||
}
|
||||
# Override build phases to skip the actual build - just need node_modules
|
||||
{
|
||||
mkDerivation = {
|
||||
buildPhase = lib.mkForce "true";
|
||||
installPhase = lib.mkForce ''
|
||||
runHook preInstall
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
}
|
||||
];
|
||||
};
|
||||
|
||||
# Filter source to only include dashboard directory
|
||||
dashboardSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
@@ -42,11 +81,12 @@
|
||||
'';
|
||||
|
||||
# Prettier with svelte plugin for treefmt
|
||||
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
|
||||
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
|
||||
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
|
||||
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
|
||||
exec ${pkgs.nodejs}/bin/node \
|
||||
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
"$@"
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -89,7 +89,10 @@
|
||||
|
||||
const isImageModel = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
return modelSupportsTextToImage(currentModel);
|
||||
return (
|
||||
modelSupportsTextToImage(currentModel) ||
|
||||
modelSupportsImageEditing(currentModel)
|
||||
);
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
@@ -646,6 +649,23 @@
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
|
||||
@@ -110,6 +110,36 @@
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function handleNumImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ numImages: 1 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 1) {
|
||||
setImageGenerationParams({ numImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamChange(enabled: boolean) {
|
||||
setImageGenerationParams({ stream: enabled });
|
||||
}
|
||||
|
||||
function handlePartialImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ partialImages: 0 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ partialImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
@@ -134,90 +164,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -325,6 +357,59 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Number of Images (not in edit mode) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>IMAGES:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="1"
|
||||
value={params.numImages}
|
||||
oninput={handleNumImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Stream toggle -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>STREAM:</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleStreamChange(!params.stream)}
|
||||
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
|
||||
? 'bg-exo-yellow'
|
||||
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
|
||||
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
|
||||
>
|
||||
<div
|
||||
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
|
||||
? 'right-0.5 bg-exo-black'
|
||||
: 'left-0.5 bg-exo-light-gray'}"
|
||||
></div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Partial Images (only when streaming) -->
|
||||
{#if params.stream}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>PARTIALS:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.partialImages}
|
||||
oninput={handlePartialImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
|
||||
@@ -5,22 +5,32 @@
|
||||
topologyData,
|
||||
isTopologyMinimized,
|
||||
debugMode,
|
||||
nodeThunderboltBridge,
|
||||
type NodeInfo,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
highlightedNodes?: Set<string>;
|
||||
filteredNodes?: Set<string>;
|
||||
onNodeClick?: (nodeId: string) => void;
|
||||
}
|
||||
|
||||
let { class: className = "", highlightedNodes = new Set() }: Props = $props();
|
||||
let {
|
||||
class: className = "",
|
||||
highlightedNodes = new Set(),
|
||||
filteredNodes = new Set(),
|
||||
onNodeClick,
|
||||
}: Props = $props();
|
||||
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
let hoveredNodeId = $state<string | null>(null);
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -522,10 +532,72 @@
|
||||
}
|
||||
}
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check node states for styling
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
const isInFilter =
|
||||
filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);
|
||||
const isFilteredOut =
|
||||
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
|
||||
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
|
||||
|
||||
// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
|
||||
const wireColor = isInFilter
|
||||
? "rgba(255,215,0,1)" // Bright yellow for filter selection
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.7)" // Subtle yellow for hover
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.9)" // Yellow for instance highlight
|
||||
: isFilteredOut
|
||||
? "rgba(140,140,140,0.6)" // Grey for filtered out
|
||||
: "rgba(179,179,179,0.8)"; // Default
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isInFilter
|
||||
? "rgba(255,215,0,0.25)"
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.12)"
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isInFilter
|
||||
? 3
|
||||
: isHovered
|
||||
? 2
|
||||
: isHighlighted
|
||||
? 2.5
|
||||
: 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
const nodeG = nodesGroup
|
||||
.append("g")
|
||||
.attr("class", "graph-node")
|
||||
.style("cursor", "pointer");
|
||||
.style("cursor", onNodeClick ? "pointer" : "default")
|
||||
.style("opacity", isFilteredOut ? 0.5 : 1);
|
||||
|
||||
// Add click and hover handlers - hover just updates state, styling is applied during render
|
||||
nodeG
|
||||
.on("click", (event: MouseEvent) => {
|
||||
if (onNodeClick) {
|
||||
event.stopPropagation();
|
||||
onNodeClick(nodeInfo.id);
|
||||
}
|
||||
})
|
||||
.on("mouseenter", () => {
|
||||
if (onNodeClick) {
|
||||
hoveredNodeId = nodeInfo.id;
|
||||
}
|
||||
})
|
||||
.on("mouseleave", () => {
|
||||
if (hoveredNodeId === nodeInfo.id) {
|
||||
hoveredNodeId = null;
|
||||
}
|
||||
});
|
||||
|
||||
// Add tooltip
|
||||
nodeG
|
||||
@@ -534,27 +606,6 @@
|
||||
`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,
|
||||
);
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check if this node should be highlighted (from hovered instance)
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
|
||||
// Holographic wireframe colors - yellow border when highlighted
|
||||
const wireColor = isHighlighted
|
||||
? "rgba(255,215,0,0.9)"
|
||||
: "rgba(179,179,179,0.8)";
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isHighlighted ? 2.5 : 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
if (modelLower === "mac studio") {
|
||||
// Mac Studio - classic cube with memory fill
|
||||
iconBaseWidth = nodeRadius * 1.25;
|
||||
@@ -579,6 +630,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -661,6 +713,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -738,6 +791,7 @@
|
||||
// Screen outer frame
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", screenX)
|
||||
.attr("y", y)
|
||||
.attr("width", screenWidth)
|
||||
@@ -846,6 +900,7 @@
|
||||
// Main shape
|
||||
nodeG
|
||||
.append("polygon")
|
||||
.attr("class", "node-outline")
|
||||
.attr("points", hexPoints)
|
||||
.attr("fill", fillColor)
|
||||
.attr("stroke", wireColor)
|
||||
@@ -1064,11 +1119,41 @@
|
||||
.attr("fill", "rgba(179,179,179,0.7)")
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// Debug mode: Show TB bridge status
|
||||
if (debugEnabled) {
|
||||
const tbStatus = tbBridgeData[nodeInfo.id];
|
||||
if (tbStatus) {
|
||||
const tbY =
|
||||
nodeInfo.y +
|
||||
iconBaseHeight / 2 +
|
||||
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const tbFontSize = showFullLabels ? 9 : 7;
|
||||
const tbColor = tbStatus.enabled
|
||||
? "rgba(234,179,8,0.9)"
|
||||
: "rgba(100,100,100,0.7)";
|
||||
const tbText = tbStatus.enabled ? "TB:ON" : "TB:OFF";
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", tbY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", tbColor)
|
||||
.attr("font-size", tbFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(tbText);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (data) {
|
||||
// Track all reactive dependencies that affect rendering
|
||||
const _data = data;
|
||||
const _hoveredNodeId = hoveredNodeId;
|
||||
const _filteredNodes = filteredNodes;
|
||||
const _highlightedNodes = highlightedNodes;
|
||||
if (_data) {
|
||||
renderGraph();
|
||||
}
|
||||
});
|
||||
@@ -1091,12 +1176,8 @@
|
||||
|
||||
<style>
|
||||
:global(.graph-node) {
|
||||
transition:
|
||||
transform 0.2s ease,
|
||||
opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-node:hover) {
|
||||
filter: brightness(1.1);
|
||||
/* Only transition opacity for filtered-out nodes, no transition on hover stroke changes */
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-link) {
|
||||
stroke: var(--exo-light-gray, #b3b3b3);
|
||||
|
||||
@@ -190,6 +190,13 @@ interface RawStateResponse {
|
||||
nodeMemory?: Record<string, RawMemoryUsage>;
|
||||
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
|
||||
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
|
||||
// Thunderbolt bridge status per node
|
||||
nodeThunderboltBridge?: Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -209,6 +216,8 @@ export interface Message {
|
||||
attachments?: MessageAttachment[];
|
||||
ttftMs?: number; // Time to first token in ms (for assistant messages)
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
requestType?: "chat" | "image-generation" | "image-editing";
|
||||
sourceImageDataUrl?: string; // For image editing regeneration
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -231,6 +240,10 @@ export interface ImageGenerationParams {
|
||||
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
// Streaming params
|
||||
stream: boolean;
|
||||
partialImages: number;
|
||||
// Advanced params
|
||||
seed: number | null;
|
||||
numInferenceSteps: number | null;
|
||||
@@ -250,6 +263,9 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
stream: true,
|
||||
partialImages: 3,
|
||||
seed: null,
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
@@ -419,7 +435,15 @@ class AppStore {
|
||||
placementPreviews = $state<PlacementPreview[]>([]);
|
||||
selectedPreviewModelId = $state<string | null>(null);
|
||||
isLoadingPreviews = $state(false);
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderboltBridge = $state<
|
||||
Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>
|
||||
>({});
|
||||
|
||||
// UI state
|
||||
isTopologyMinimized = $state(false);
|
||||
@@ -439,6 +463,7 @@ class AppStore {
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private previousNodeIds: Set<string> = new Set();
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -997,6 +1022,8 @@ class AppStore {
|
||||
nodeSystem: data.nodeSystem,
|
||||
nodeNetwork: data.nodeNetwork,
|
||||
});
|
||||
// Handle topology changes for preview filter
|
||||
this.handleTopologyChange();
|
||||
}
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
@@ -1008,6 +1035,10 @@ class AppStore {
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};
|
||||
this.lastUpdate = Date.now();
|
||||
} catch (error) {
|
||||
console.error("Error fetching state:", error);
|
||||
@@ -1023,9 +1054,14 @@ class AppStore {
|
||||
this.selectedPreviewModelId = modelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/instance/previews?model_id=${encodeURIComponent(modelId)}`,
|
||||
);
|
||||
let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;
|
||||
// Add node filter if active
|
||||
if (this.previewNodeFilter.size > 0) {
|
||||
for (const nodeId of this.previewNodeFilter) {
|
||||
url += `&node_ids=${encodeURIComponent(nodeId)}`;
|
||||
}
|
||||
}
|
||||
const response = await fetch(url);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch placement previews: ${response.status}`,
|
||||
@@ -1075,6 +1111,71 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle a node in the preview filter and re-fetch placements
|
||||
*/
|
||||
togglePreviewNodeFilter(nodeId: string) {
|
||||
const next = new Set(this.previewNodeFilter);
|
||||
if (next.has(nodeId)) {
|
||||
next.delete(nodeId);
|
||||
} else {
|
||||
next.add(nodeId);
|
||||
}
|
||||
this.previewNodeFilter = next;
|
||||
// Re-fetch with new filter if we have a selected model
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the preview node filter and re-fetch placements
|
||||
*/
|
||||
clearPreviewNodeFilter() {
|
||||
this.previewNodeFilter = new Set();
|
||||
// Re-fetch with no filter if we have a selected model
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle topology changes - clean up filter and re-fetch if needed
|
||||
*/
|
||||
private handleTopologyChange() {
|
||||
if (!this.topologyData) return;
|
||||
|
||||
const currentNodeIds = new Set(Object.keys(this.topologyData.nodes));
|
||||
|
||||
// Check if nodes have changed
|
||||
const nodesAdded = [...currentNodeIds].some(
|
||||
(id) => !this.previousNodeIds.has(id),
|
||||
);
|
||||
const nodesRemoved = [...this.previousNodeIds].some(
|
||||
(id) => !currentNodeIds.has(id),
|
||||
);
|
||||
|
||||
if (nodesAdded || nodesRemoved) {
|
||||
// Clean up filter - remove any nodes that no longer exist
|
||||
if (this.previewNodeFilter.size > 0) {
|
||||
const validFilterNodes = new Set(
|
||||
[...this.previewNodeFilter].filter((id) => currentNodeIds.has(id)),
|
||||
);
|
||||
if (validFilterNodes.size !== this.previewNodeFilter.size) {
|
||||
this.previewNodeFilter = validFilterNodes;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-fetch previews if we have a selected model (topology changed)
|
||||
if (this.selectedPreviewModelId) {
|
||||
this.fetchPlacementPreviews(this.selectedPreviewModelId, false);
|
||||
}
|
||||
}
|
||||
|
||||
// Update tracked node IDs for next comparison
|
||||
this.previousNodeIds = currentNodeIds;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a chat conversation - triggers the topology minimization animation
|
||||
* Creates a new conversation if none is active
|
||||
@@ -1171,10 +1272,46 @@ class AppStore {
|
||||
|
||||
if (lastUserIndex === -1) return;
|
||||
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
const lastUserMessage = this.messages[lastUserIndex];
|
||||
const requestType = lastUserMessage.requestType || "chat";
|
||||
const prompt = lastUserMessage.content;
|
||||
|
||||
// Resend the message to get a new response
|
||||
// Remove messages after user message (including the user message for image requests
|
||||
// since generateImage/editImage will re-add it)
|
||||
this.messages = this.messages.slice(0, lastUserIndex);
|
||||
|
||||
switch (requestType) {
|
||||
case "image-generation":
|
||||
await this.generateImage(prompt);
|
||||
break;
|
||||
case "image-editing":
|
||||
if (lastUserMessage.sourceImageDataUrl) {
|
||||
await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
|
||||
} else {
|
||||
// Can't regenerate edit without source image - restore user message and show error
|
||||
this.messages.push(lastUserMessage);
|
||||
const errorMessage = this.addMessage("assistant", "");
|
||||
const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content =
|
||||
"Error: Cannot regenerate image edit - source image not found";
|
||||
}
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
break;
|
||||
case "chat":
|
||||
default:
|
||||
// Restore the user message for chat regeneration
|
||||
this.messages.push(lastUserMessage);
|
||||
await this.regenerateChatCompletion();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to regenerate a chat completion response
|
||||
*/
|
||||
private async regenerateChatCompletion(): Promise<void> {
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -1689,6 +1826,7 @@ class AppStore {
|
||||
role: "user",
|
||||
content: prompt,
|
||||
timestamp: Date.now(),
|
||||
requestType: "image-generation",
|
||||
};
|
||||
this.messages.push(userMessage);
|
||||
|
||||
@@ -1717,12 +1855,13 @@ class AppStore {
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
prompt,
|
||||
n: params.numImages,
|
||||
quality: params.quality,
|
||||
size: params.size,
|
||||
output_format: params.outputFormat,
|
||||
response_format: "b64_json",
|
||||
stream: true,
|
||||
partial_images: 3,
|
||||
stream: params.stream,
|
||||
partial_images: params.partialImages,
|
||||
};
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
@@ -1786,31 +1925,74 @@ class AppStore {
|
||||
if (imageData && idx !== -1) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
const numImages = params.numImages;
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.messages[idx].content =
|
||||
`Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].content = progressText;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
this.messages[idx].attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
this.messages[idx].attachments = [
|
||||
...finals,
|
||||
partialAttachment,
|
||||
];
|
||||
}
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.messages[idx].content = "";
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
this.messages[idx].attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
this.messages[idx].attachments = [
|
||||
...previousFinals,
|
||||
newAttachment,
|
||||
];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
this.messages[idx].content =
|
||||
`Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
this.messages[idx].content = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
@@ -1855,6 +2037,8 @@ class AppStore {
|
||||
role: "user",
|
||||
content: prompt,
|
||||
timestamp: Date.now(),
|
||||
requestType: "image-editing",
|
||||
sourceImageDataUrl: imageDataUrl,
|
||||
};
|
||||
this.messages.push(userMessage);
|
||||
|
||||
@@ -1891,8 +2075,8 @@ class AppStore {
|
||||
formData.append("size", params.size);
|
||||
formData.append("output_format", params.outputFormat);
|
||||
formData.append("response_format", "b64_json");
|
||||
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
|
||||
formData.append("partial_images", "3");
|
||||
formData.append("stream", params.stream ? "1" : "0");
|
||||
formData.append("partial_images", params.partialImages.toString());
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
@@ -2044,6 +2228,54 @@ class AppStore {
|
||||
this.conversations.find((c) => c.id === this.activeConversationId) || null
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a download on a specific node
|
||||
*/
|
||||
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
|
||||
try {
|
||||
const response = await fetch("/download/start", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
targetNodeId: nodeId,
|
||||
shardMetadata: shardMetadata,
|
||||
}),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Failed to start download: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error starting download:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a downloaded model from a specific node
|
||||
*/
|
||||
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
},
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Failed to delete download: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting download:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -2098,6 +2330,10 @@ export const setSelectedChatModel = (modelId: string) =>
|
||||
appStore.setSelectedModel(modelId);
|
||||
export const selectPreviewModel = (modelId: string | null) =>
|
||||
appStore.selectPreviewModel(modelId);
|
||||
export const togglePreviewNodeFilter = (nodeId: string) =>
|
||||
appStore.togglePreviewNodeFilter(nodeId);
|
||||
export const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();
|
||||
export const previewNodeFilter = () => appStore.previewNodeFilter;
|
||||
export const deleteMessage = (messageId: string) =>
|
||||
appStore.deleteMessage(messageId);
|
||||
export const editMessage = (messageId: string, newContent: string) =>
|
||||
@@ -2134,6 +2370,10 @@ export const setChatSidebarVisible = (visible: boolean) =>
|
||||
appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
// Thunderbolt bridge status
|
||||
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
|
||||
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
|
||||
|
||||
// Image generation params
|
||||
export const imageGenerationParams = () => appStore.getImageGenerationParams();
|
||||
export const setImageGenerationParams = (
|
||||
@@ -2141,3 +2381,9 @@ export const setImageGenerationParams = (
|
||||
) => appStore.setImageGenerationParams(params);
|
||||
export const resetImageGenerationParams = () =>
|
||||
appStore.resetImageGenerationParams();
|
||||
|
||||
// Download actions
|
||||
export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
@@ -19,6 +19,9 @@
|
||||
selectedPreviewModelId,
|
||||
isLoadingPreviews,
|
||||
selectPreviewModel,
|
||||
togglePreviewNodeFilter,
|
||||
clearPreviewNodeFilter,
|
||||
previewNodeFilter,
|
||||
createConversation,
|
||||
setSelectedChatModel,
|
||||
selectedChatModel,
|
||||
@@ -28,6 +31,8 @@
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview,
|
||||
} from "$lib/stores/app.svelte";
|
||||
@@ -49,6 +54,41 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const nodeFilter = $derived(previewNodeFilter());
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8) + "...";
|
||||
}
|
||||
|
||||
// Helper to get the thunderbolt bridge service name from a cycle
|
||||
function getTbBridgeServiceName(cycle: string[]): string {
|
||||
// Try to find service name from any node in the cycle
|
||||
for (const nodeId of cycle) {
|
||||
const nodeData = tbBridgeData?.[nodeId];
|
||||
if (nodeData?.serviceName) {
|
||||
return nodeData.serviceName;
|
||||
}
|
||||
}
|
||||
return "Thunderbolt Bridge"; // Fallback if no service name found
|
||||
}
|
||||
|
||||
// Copy to clipboard state and function
|
||||
let copiedCommand = $state(false);
|
||||
async function copyToClipboard(text: string) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
copiedCommand = true;
|
||||
setTimeout(() => {
|
||||
copiedCommand = false;
|
||||
}, 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
}
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -90,6 +130,15 @@
|
||||
model.tasks.includes("ImageToImage")
|
||||
);
|
||||
}
|
||||
|
||||
// Helper to check if a model supports image editing
|
||||
function modelSupportsImageEditing(modelId: string): boolean {
|
||||
const model = models.find(
|
||||
(m) => m.id === modelId || m.hugging_face_id === modelId,
|
||||
);
|
||||
if (!model?.tasks) return false;
|
||||
return model.tasks.includes("ImageToImage");
|
||||
}
|
||||
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
|
||||
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
|
||||
@@ -181,6 +230,9 @@
|
||||
// Preview card hover state for highlighting nodes in topology
|
||||
let hoveredPreviewNodes = $state<Set<string>>(new Set());
|
||||
|
||||
// Computed: Check if filter is active (from store)
|
||||
const isFilterActive = $derived(() => nodeFilter.size > 0);
|
||||
|
||||
// Helper to unwrap tagged instance for hover highlighting
|
||||
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object")
|
||||
@@ -732,6 +784,8 @@
|
||||
instanceWrapped: unknown,
|
||||
): {
|
||||
isDownloading: boolean;
|
||||
isFailed: boolean;
|
||||
errorMessage: string | null;
|
||||
progress: DownloadProgress | null;
|
||||
statusText: string;
|
||||
perNode: Array<{
|
||||
@@ -743,6 +797,8 @@
|
||||
if (!downloadsData || Object.keys(downloadsData).length === 0) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "RUNNING",
|
||||
perNode: [],
|
||||
@@ -754,6 +810,8 @@
|
||||
if (!instance || typeof instance !== "object") {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "PREPARING",
|
||||
perNode: [],
|
||||
@@ -809,6 +867,26 @@
|
||||
downloadKind
|
||||
] as Record<string, unknown>;
|
||||
|
||||
// Handle DownloadFailed - return immediately with error info
|
||||
if (downloadKind === "DownloadFailed") {
|
||||
const downloadModelId = extractModelIdFromDownload(downloadPayload);
|
||||
if (
|
||||
instanceModelId &&
|
||||
downloadModelId &&
|
||||
downloadModelId === instanceModelId
|
||||
) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: true,
|
||||
errorMessage:
|
||||
(downloadPayload.errorMessage as string) || "Download failed",
|
||||
progress: null,
|
||||
statusText: "FAILED",
|
||||
perNode: [],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (downloadKind !== "DownloadOngoing") continue;
|
||||
if (!downloadPayload) continue;
|
||||
|
||||
@@ -844,6 +922,8 @@
|
||||
const statusInfo = deriveInstanceStatus(instanceWrapped);
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
@@ -856,6 +936,8 @@
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
@@ -1451,7 +1533,7 @@
|
||||
|
||||
// Get ALL filtered previews based on current settings (matching minimum nodes)
|
||||
// Note: previewsData already contains previews for the selected model (fetched via API)
|
||||
// We filter by sharding/instance type and min nodes, returning ALL eligible previews
|
||||
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
|
||||
const filteredPreviews = $derived(() => {
|
||||
if (!selectedModelId || previewsData.length === 0) return [];
|
||||
|
||||
@@ -1584,7 +1666,86 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
@@ -1624,7 +1785,111 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Node Filter Indicator (top-right corner) -->
|
||||
{#if isFilterActive()}
|
||||
<button
|
||||
onclick={clearPreviewNodeFilter}
|
||||
class="absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Clear filter"
|
||||
>
|
||||
<span class="text-[10px] font-mono tracking-wider">
|
||||
FILTER: {nodeFilter.size}
|
||||
</span>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Chat Input - Below topology -->
|
||||
@@ -2061,6 +2326,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -2106,6 +2378,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
foundModel.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
foundModel.id,
|
||||
)}
|
||||
<span
|
||||
class="flex items-center justify-between gap-2 w-full pr-4"
|
||||
>
|
||||
@@ -2132,6 +2407,22 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate"
|
||||
>{foundModel.name || foundModel.id}</span
|
||||
>
|
||||
@@ -2204,6 +2495,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
model.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
model.id,
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
@@ -2244,6 +2538,23 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
aria-label="Image editing model"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span
|
||||
@@ -2564,7 +2875,36 @@
|
||||
<div
|
||||
class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden"
|
||||
>
|
||||
<TopologyGraph highlightedNodes={highlightedNodes()} />
|
||||
<TopologyGraph
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected - click to view details"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200"
|
||||
>TB CYCLE</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
|
||||
@@ -2993,6 +3333,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
type DownloadProgress,
|
||||
refreshState,
|
||||
lastUpdate as lastUpdateStore,
|
||||
startDownload,
|
||||
deleteDownload,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
@@ -28,6 +30,7 @@
|
||||
etaMs: number;
|
||||
status: "completed" | "downloading";
|
||||
files: FileProgress[];
|
||||
shardMetadata?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type NodeEntry = {
|
||||
@@ -172,33 +175,6 @@
|
||||
}
|
||||
|
||||
let downloadOverview = $state<NodeEntry[]>([]);
|
||||
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
|
||||
[],
|
||||
);
|
||||
|
||||
async function fetchModels() {
|
||||
try {
|
||||
const response = await fetch("/models");
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
models = data.data || [];
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch models:", error);
|
||||
}
|
||||
}
|
||||
|
||||
function getModelTotalBytes(
|
||||
modelId: string,
|
||||
downloadTotalBytes: number,
|
||||
): number {
|
||||
if (downloadTotalBytes > 0) return downloadTotalBytes;
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
if (model?.storage_size_megabytes) {
|
||||
return model.storage_size_megabytes * 1024 * 1024;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
try {
|
||||
@@ -296,6 +272,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Extract shard_metadata for use with download actions
|
||||
const shardMetadata = (downloadPayload.shard_metadata ??
|
||||
downloadPayload.shardMetadata) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
|
||||
const entry: ModelEntry = {
|
||||
modelId,
|
||||
prettyName,
|
||||
@@ -312,6 +294,7 @@
|
||||
? "completed"
|
||||
: "downloading",
|
||||
files,
|
||||
shardMetadata,
|
||||
};
|
||||
|
||||
const existing = modelMap.get(modelId);
|
||||
@@ -373,7 +356,6 @@
|
||||
onMount(() => {
|
||||
// Ensure we fetch at least once when visiting downloads directly
|
||||
refreshState();
|
||||
fetchModels();
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -482,7 +464,7 @@
|
||||
{#if model.status !== "completed"}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(
|
||||
getModelTotalBytes(model.modelId, model.totalBytes),
|
||||
model.totalBytes,
|
||||
)}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -497,6 +479,52 @@
|
||||
>
|
||||
{pct.toFixed(1)}%
|
||||
</span>
|
||||
{#if model.status !== "completed" && model.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
onclick={() =>
|
||||
startDownload(node.nodeId, model.shardMetadata!)}
|
||||
title="Start download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
{#if model.status === "completed"}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-red-400 transition-colors"
|
||||
onclick={() =>
|
||||
deleteDownload(node.nodeId, model.modelId)}
|
||||
title="Delete download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
|
||||
284
src/exo/download/coordinator.py
Normal file
284
src/exo/download/coordinator.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import (
|
||||
RepoDownloadProgress,
|
||||
delete_model,
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.download_command_receiver as commands:
|
||||
async for cmd in commands:
|
||||
# Only process commands targeting this node
|
||||
if cmd.command.target_node_id != self.node_id:
|
||||
continue
|
||||
|
||||
match cmd.command:
|
||||
case StartDownload(shard_metadata=shard):
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Check if already downloading or complete
|
||||
if model_id in self.download_status:
|
||||
status = self.download_status[model_id]
|
||||
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
|
||||
logger.debug(
|
||||
f"Download for {model_id} already in progress or complete, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
# Check initial status from downloader
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
||||
)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
def _start_download_task(
|
||||
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Emit ongoing status
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal last_progress_time
|
||||
|
||||
if progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
# Clean up active download tracking
|
||||
if callback_shard.model_card.model_id in self.active_downloads:
|
||||
del self.active_downloads[callback_shard.model_card.model_id]
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
ongoing = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=callback_shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=ongoing)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_wrapper() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(shard)
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed for {model_id}: {e}")
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed)
|
||||
)
|
||||
finally:
|
||||
if model_id in self.active_downloads:
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
task = asyncio.create_task(download_wrapper())
|
||||
self.active_downloads[model_id] = task
|
||||
|
||||
async def _delete_download(self, model_id: ModelId) -> None:
|
||||
# Cancel if active
|
||||
if model_id in self.active_downloads:
|
||||
logger.info(f"Cancelling active download for {model_id} before deletion")
|
||||
self.active_downloads[model_id].cancel()
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
# Delete from disk
|
||||
logger.info(f"Deleting model files for {model_id}")
|
||||
deleted = await delete_model(model_id)
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Successfully deleted model {model_id}")
|
||||
else:
|
||||
logger.warning(f"Model {model_id} was not found on disk")
|
||||
|
||||
# Emit pending status to reset UI state, then remove from local tracking
|
||||
if model_id in self.download_status:
|
||||
current_status = self.download_status[model_id]
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info(
|
||||
"DownloadCoordinator: Fetching and emitting existing download progress..."
|
||||
)
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info(
|
||||
"DownloadCoordinator: Done emitting existing download progress."
|
||||
)
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DownloadCoordinator: Error emitting existing download progress: {e}"
|
||||
)
|
||||
@@ -24,6 +24,13 @@ from pydantic import (
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -35,12 +42,27 @@ from exo.shared.types.worker.downloads import (
|
||||
RepoFileDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceAuthenticationError(Exception):
|
||||
"""Raised when HuggingFace returns 401/403 for a model download."""
|
||||
|
||||
|
||||
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
|
||||
token = await get_hf_token()
|
||||
if status_code == 401 and token is None:
|
||||
return (
|
||||
f"Model '{model_id}' requires authentication. "
|
||||
f"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. "
|
||||
f"Get a token at https://huggingface.co/settings/tokens"
|
||||
)
|
||||
elif status_code == 403:
|
||||
return (
|
||||
f"Access denied to '{model_id}'. "
|
||||
f"Please accept the model terms at https://huggingface.co/{model_id}"
|
||||
)
|
||||
else:
|
||||
return f"Authentication failed for '{model_id}' (HTTP {status_code})"
|
||||
|
||||
|
||||
def trim_etag(etag: str) -> str:
|
||||
@@ -147,6 +169,8 @@ async def fetch_file_list_with_retry(
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -167,6 +191,9 @@ async def _fetch_file_list(
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.get(url, headers=headers) as response,
|
||||
):
|
||||
if response.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if response.status == 200:
|
||||
data_json = await response.text()
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
@@ -256,6 +283,9 @@ async def file_meta(
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -279,6 +309,8 @@ async def download_file_with_retry(
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -322,6 +354,9 @@ async def _download_file(
|
||||
):
|
||||
if r.status == 404:
|
||||
raise FileNotFoundError(f"File not found: {url}")
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
)
|
||||
@@ -463,7 +498,7 @@ async def download_shard(
|
||||
allow_patterns: list[str] | None = None,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=}")
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=}")
|
||||
|
||||
revision = "main"
|
||||
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
|
||||
@@ -476,7 +511,7 @@ async def download_shard(
|
||||
allow_patterns = await resolve_allow_patterns(shard)
|
||||
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
@@ -68,7 +68,11 @@ def get_hf_home() -> Path:
|
||||
|
||||
|
||||
async def get_hf_token() -> str | None:
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
"""Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory."""
|
||||
# Check environment variable first
|
||||
if token := os.environ.get("HF_TOKEN"):
|
||||
return token
|
||||
# Fall back to file-based token
|
||||
token_path = get_hf_home() / "token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, "r") as f:
|
||||
@@ -3,13 +3,15 @@ from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.worker.download.shard_downloader import ShardDownloader
|
||||
|
||||
|
||||
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
@@ -19,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -166,7 +168,7 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
print("Error downloading shard:", e)
|
||||
logger.error("Error downloading shard:", e)
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
@@ -5,13 +5,13 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
@@ -1,10 +1,11 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
from typing import Iterator, Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -12,6 +13,8 @@ from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
import exo.routing.topics as topics
|
||||
from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
@@ -21,7 +24,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
@@ -29,6 +31,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
download_coordinator: DownloadCoordinator | None
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
@@ -36,6 +39,7 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -49,8 +53,26 @@ class Node:
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
await router.register_topic(topics.DOWNLOAD_COMMANDS)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
@@ -58,6 +80,7 @@ class Node:
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
)
|
||||
else:
|
||||
@@ -67,11 +90,12 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -99,13 +123,25 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
return cls(
|
||||
router,
|
||||
download_coordinator,
|
||||
worker,
|
||||
election,
|
||||
er_recv,
|
||||
master,
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
@@ -170,13 +206,27 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
self.worker.shutdown()
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
@@ -185,6 +235,10 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
@@ -226,6 +280,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -268,6 +323,11 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-downloads",
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Literal, cast
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -20,7 +22,10 @@ from loguru import logger
|
||||
from exo.master.image_store import ImageStore
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_IMAGE_CACHE_DIR, EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.constants import (
|
||||
EXO_IMAGE_CACHE_DIR,
|
||||
EXO_MAX_CHUNK_SIZE,
|
||||
)
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import (
|
||||
@@ -29,6 +34,7 @@ from exo.shared.models.model_cards import (
|
||||
ModelId,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
AdvancedImageParams,
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
BenchImageGenerationResponse,
|
||||
@@ -38,6 +44,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
@@ -55,19 +62,32 @@ from exo.shared.types.api import (
|
||||
PlaceInstanceParams,
|
||||
PlacementPreview,
|
||||
PlacementPreviewResponse,
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -93,7 +113,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
@@ -102,7 +122,19 @@ def chunk_to_response(
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
|
||||
if isinstance(chunk, TokenChunk)
|
||||
else ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
],
|
||||
),
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -131,12 +163,14 @@ class API:
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
election_receiver: Receiver[ElectionMessage],
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log: list[Event] = []
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
@@ -162,8 +196,12 @@ class API:
|
||||
name="dashboard",
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
] = {}
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -231,6 +269,8 @@ class API:
|
||||
self.app.get("/images/{image_id}")(self.get_image)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -310,11 +350,20 @@ class API:
|
||||
return placements[new_ids[0]]
|
||||
|
||||
async def get_placement_previews(
|
||||
self, model_id: ModelId
|
||||
self,
|
||||
model_id: ModelId,
|
||||
node_ids: Annotated[list[NodeId] | None, Query()] = None,
|
||||
) -> PlacementPreviewResponse:
|
||||
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
|
||||
previews: list[PlacementPreview] = []
|
||||
if len(list(self.state.topology.list_nodes())) == 0:
|
||||
|
||||
# Create filtered topology if node_ids specified
|
||||
if node_ids and len(node_ids) > 0:
|
||||
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
|
||||
else:
|
||||
topology = self.state.topology
|
||||
|
||||
if len(list(topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
|
||||
@@ -327,9 +376,7 @@ class API:
|
||||
instance_combinations.extend(
|
||||
[
|
||||
(sharding, instance_meta, i)
|
||||
for i in range(
|
||||
1, len(list(self.state.topology.list_nodes())) + 1
|
||||
)
|
||||
for i in range(1, len(list(topology.list_nodes())) + 1)
|
||||
]
|
||||
)
|
||||
# TODO: PDD
|
||||
@@ -347,7 +394,7 @@ class API:
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
topology=topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
except ValueError as exc:
|
||||
@@ -439,11 +486,13 @@ class API:
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
self._chat_completion_queues[command_id], recv = channel[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
@@ -462,7 +511,8 @@ class API:
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
@@ -470,6 +520,7 @@ class API:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
@@ -498,11 +549,12 @@ class API:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
@@ -511,7 +563,18 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
@@ -529,6 +592,7 @@ class API:
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -539,6 +603,7 @@ class API:
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
@@ -554,7 +619,19 @@ class API:
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -571,7 +648,7 @@ class API:
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
role="assistant", content=combined_text, tool_calls=tool_calls
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
@@ -729,7 +806,9 @@ class API:
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
@@ -769,6 +848,7 @@ class API:
|
||||
# Yield partial image event (always use b64_json for partials)
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"image_index": chunk.image_index,
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"format": str(chunk.format),
|
||||
@@ -838,7 +918,9 @@ class API:
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
self._image_generation_queues[command_id], recv = channel[
|
||||
ImageChunk | ErrorChunk
|
||||
]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
@@ -956,6 +1038,9 @@ class API:
|
||||
stream: bool,
|
||||
partial_images: int,
|
||||
bench: bool,
|
||||
quality: Literal["high", "medium", "low"],
|
||||
output_format: Literal["png", "jpeg", "webp"],
|
||||
advanced_params: AdvancedImageParams | None,
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
@@ -984,6 +1069,9 @@ class API:
|
||||
stream=stream,
|
||||
partial_images=partial_images,
|
||||
bench=bench,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=advanced_params,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -994,7 +1082,6 @@ class API:
|
||||
await self._send(
|
||||
SendInputChunk(
|
||||
chunk=InputImageChunk(
|
||||
idx=chunk_index,
|
||||
model=resolved_model,
|
||||
command_id=command.command_id,
|
||||
data=chunk_data,
|
||||
@@ -1019,12 +1106,22 @@ class API:
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
partial_images: str = Form("0"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
|
||||
advanced_params: str | None = Form(None),
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image editing requests (img2img)."""
|
||||
# Parse string form values to proper types
|
||||
stream_bool = stream.lower() in ("true", "1", "yes")
|
||||
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
|
||||
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
with contextlib.suppress(Exception):
|
||||
parsed_advanced_params = AdvancedImageParams.model_validate_json(
|
||||
advanced_params
|
||||
)
|
||||
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
@@ -1036,6 +1133,9 @@ class API:
|
||||
stream=stream_bool,
|
||||
partial_images=partial_images_int,
|
||||
bench=False,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=parsed_advanced_params,
|
||||
)
|
||||
|
||||
if stream_bool and partial_images_int > 0:
|
||||
@@ -1066,8 +1166,18 @@ class API:
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
|
||||
advanced_params: str | None = Form(None),
|
||||
) -> BenchImageGenerationResponse:
|
||||
"""Handle benchmark image editing requests with generation stats."""
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
with contextlib.suppress(Exception):
|
||||
parsed_advanced_params = AdvancedImageParams.model_validate_json(
|
||||
advanced_params
|
||||
)
|
||||
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
@@ -1079,6 +1189,9 @@ class API:
|
||||
stream=False,
|
||||
partial_images=0,
|
||||
bench=True,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=parsed_advanced_params,
|
||||
)
|
||||
|
||||
return await self._collect_image_generation_with_stats(
|
||||
@@ -1148,27 +1261,26 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if event.command_id in self._chat_completion_queues:
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
elif event.command_id in self._image_generation_queues:
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
queue = self._image_generation_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
if queue := self._chat_completion_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
@@ -1191,3 +1303,28 @@ class API:
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def _send_download(self, command: DownloadCommand):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def start_download(
|
||||
self, payload: StartDownloadParams
|
||||
) -> StartDownloadResponse:
|
||||
command = StartDownload(
|
||||
target_node_id=payload.target_node_id,
|
||||
shard_metadata=payload.shard_metadata,
|
||||
)
|
||||
await self._send_download(command)
|
||||
return StartDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def delete_download(
|
||||
self, node_id: NodeId, model_id: ModelId
|
||||
) -> DeleteDownloadResponse:
|
||||
command = DeleteDownload(
|
||||
target_node_id=node_id,
|
||||
model_id=ModelId(model_id),
|
||||
)
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
@@ -87,12 +87,12 @@ def place_instance(
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_tb_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if smallest_tb_cycles != []:
|
||||
smallest_cycles = smallest_tb_cycles
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
|
||||
@@ -197,49 +197,6 @@ def get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
cycles = cycle_digraph.get_cycles()
|
||||
expected_length = len(list(cycle_digraph.list_nodes()))
|
||||
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
|
||||
if not cycles:
|
||||
if expected_length > 1:
|
||||
logger.warning(
|
||||
f"No cycles of length {expected_length} found even though chosen subgraph contained {expected_length} nodes"
|
||||
)
|
||||
return []
|
||||
|
||||
cycle = cycles[0]
|
||||
|
||||
get_thunderbolt = False
|
||||
if cycle_digraph.is_thunderbolt_cycle(cycle):
|
||||
get_thunderbolt = True
|
||||
|
||||
logger.debug(f"Using thunderbolt cycle: {get_thunderbolt}")
|
||||
|
||||
hosts: list[Host] = []
|
||||
for i in range(len(cycle)):
|
||||
current_node = cycle.node_ids[i]
|
||||
next_node = cycle.node_ids[(i + 1) % len(cycle)]
|
||||
|
||||
for connection in cycle_digraph.get_all_connections_between(
|
||||
source=current_node, sink=next_node
|
||||
):
|
||||
if not isinstance(connection, SocketConnection):
|
||||
continue
|
||||
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
|
||||
host = Host(
|
||||
ip=connection.sink_multiaddr.ip_address,
|
||||
port=connection.sink_multiaddr.port,
|
||||
)
|
||||
hosts.append(host)
|
||||
break
|
||||
|
||||
return hosts
|
||||
|
||||
|
||||
def get_mlx_jaccl_devices_matrix(
|
||||
selected_cycle: list[NodeId],
|
||||
cycle_digraph: Topology,
|
||||
@@ -265,9 +222,6 @@ def get_mlx_jaccl_devices_matrix(
|
||||
matrix[i][j] = conn.source_rdma_iface
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to find interface name between {node_i} and {node_j}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all-to-all RDMA connections"
|
||||
)
|
||||
@@ -279,22 +233,11 @@ def _find_connection_ip(
|
||||
node_i: NodeId,
|
||||
node_j: NodeId,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[tuple[str, bool]]:
|
||||
) -> Generator[str, None, None]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||
if isinstance(connection, SocketConnection):
|
||||
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
ip_address: str, node_network: NodeNetworkInfo
|
||||
) -> str | None:
|
||||
"""Find the interface name for an IP address on a node (any interface)."""
|
||||
for interface in node_network.interfaces:
|
||||
if interface.ip_address == ip_address:
|
||||
return interface.name
|
||||
|
||||
return None
|
||||
yield connection.sink_multiaddr.ip_address
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
@@ -303,43 +246,25 @@ def _find_ip_prioritised(
|
||||
cycle_digraph: Topology,
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
) -> str | None:
|
||||
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
|
||||
"""Find an IP address between nodes with prioritization.
|
||||
|
||||
Priority order:
|
||||
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
|
||||
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
|
||||
3. Non-Thunderbolt connections
|
||||
4. Any other IP address
|
||||
Priority: ethernet > wifi > unknown > thunderbolt
|
||||
"""
|
||||
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
|
||||
# We expect a unique iface -> ip mapping
|
||||
iface_map = {
|
||||
_find_interface_name_for_ip(
|
||||
ip, node_network.get(other_node_id, NodeNetworkInfo())
|
||||
): ip
|
||||
for ip, _ in ips
|
||||
if not ips:
|
||||
return None
|
||||
other_network = node_network.get(other_node_id, NodeNetworkInfo())
|
||||
ip_to_type = {
|
||||
iface.ip_address: iface.interface_type for iface in other_network.interfaces
|
||||
}
|
||||
|
||||
en0_ip = iface_map.get("en0")
|
||||
if en0_ip:
|
||||
return en0_ip
|
||||
|
||||
en1_ip = iface_map.get("en1")
|
||||
if en1_ip:
|
||||
return en1_ip
|
||||
|
||||
non_thunderbolt_ip = next(
|
||||
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
|
||||
)
|
||||
|
||||
if non_thunderbolt_ip:
|
||||
return non_thunderbolt_ip
|
||||
|
||||
if ips:
|
||||
return ips[0][0]
|
||||
|
||||
return None
|
||||
priority = {
|
||||
"ethernet": 0,
|
||||
"wifi": 1,
|
||||
"unknown": 2,
|
||||
"maybe_ethernet": 3,
|
||||
"thunderbolt": 4,
|
||||
}
|
||||
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
@@ -381,9 +306,6 @@ def get_mlx_ring_hosts_by_node(
|
||||
node_id, other_node_id, cycle_digraph, node_network
|
||||
)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
@@ -416,9 +338,6 @@ def get_mlx_jaccl_coordinators(
|
||||
if ip is not None:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all participating devices to be able to communicate"
|
||||
)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
@@ -48,95 +44,3 @@ def test_http_exception_handler_formats_openai_style() -> None:
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
|
||||
|
||||
def test_image_chunk_with_error_fields() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
finish_reason="error",
|
||||
error_message="Image generation failed",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Image generation failed"
|
||||
assert chunk.data == ""
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.total_chunks == 1
|
||||
assert chunk.image_index == 0
|
||||
|
||||
|
||||
def test_image_chunk_without_error() -> None:
|
||||
chunk = ImageChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
data="base64encodeddata",
|
||||
chunk_index=0,
|
||||
total_chunks=1,
|
||||
image_index=0,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
assert chunk.data == "base64encodeddata"
|
||||
|
||||
@@ -3,7 +3,6 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
@@ -14,7 +13,7 @@ from exo.master.tests.conftest import (
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
NetworkInterfaceInfo,
|
||||
@@ -273,45 +272,6 @@ def test_get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def test_get_hosts_from_subgraph():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
connection1 = Connection(
|
||||
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
|
||||
)
|
||||
connection2 = Connection(
|
||||
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
|
||||
)
|
||||
connection3 = Connection(
|
||||
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
|
||||
)
|
||||
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
topology.add_connection(connection3)
|
||||
|
||||
# act
|
||||
hosts = get_hosts_from_subgraph(topology)
|
||||
|
||||
# assert
|
||||
assert len(hosts) == 3
|
||||
expected_hosts = [
|
||||
Host(ip="169.254.0.1", port=1234),
|
||||
Host(ip="169.254.0.2", port=1234),
|
||||
Host(ip="169.254.0.3", port=1234),
|
||||
]
|
||||
for expected_host in expected_hosts:
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_jaccl_coordinators():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.types.commands import ForwarderCommand
|
||||
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
|
||||
CONNECTION_MESSAGES = TypedTopic(
|
||||
"connection_messages", PublishPolicy.Never, ConnectionMessage
|
||||
)
|
||||
DOWNLOAD_COMMANDS = TypedTopic(
|
||||
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -46,6 +47,7 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
NodeConfig,
|
||||
NodeNetworkInterfaces,
|
||||
StaticNodeInformation,
|
||||
ThunderboltBridgeInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -225,6 +227,21 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
for key, value in state.node_thunderbolt.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_thunderbolt_bridge = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
# Only recompute cycles if the leaving node had TB bridge enabled
|
||||
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
leaving_node_had_tb_enabled = (
|
||||
leaving_node_status is not None and leaving_node_status.enabled
|
||||
)
|
||||
thunderbolt_bridge_cycles = (
|
||||
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
|
||||
if leaving_node_had_tb_enabled
|
||||
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
|
||||
)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
@@ -235,6 +252,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -312,6 +331,22 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
]
|
||||
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
|
||||
case ThunderboltBridgeInfo():
|
||||
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
|
||||
**state.node_thunderbolt_bridge,
|
||||
event.node_id: info.status,
|
||||
}
|
||||
update["node_thunderbolt_bridge"] = new_tb_bridge
|
||||
# Only recompute cycles if the enabled status changed
|
||||
old_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
old_enabled = old_status.enabled if old_status else False
|
||||
new_enabled = info.status.enabled
|
||||
if old_enabled != new_enabled:
|
||||
update["thunderbolt_bridge_cycles"] = (
|
||||
topology.get_thunderbolt_bridge_cycles(
|
||||
new_tb_bridge, state.node_network
|
||||
)
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
@@ -49,3 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
|
||||
|
||||
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
|
||||
|
||||
EXO_ENABLE_IMAGE_MODELS = (
|
||||
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -410,161 +411,166 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
|
||||
# "flux1-schnell": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23782357120),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "flux1-dev": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23802816640),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image-edit-2509": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
@@ -615,7 +621,7 @@ class ConfigData(BaseModel):
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
@@ -627,7 +633,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
|
||||
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
@@ -637,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
@@ -650,7 +656,7 @@ async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"main",
|
||||
"model.safetensors.index.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
|
||||
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -7,6 +7,11 @@ import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import (
|
||||
InterfaceType,
|
||||
NodeNetworkInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.topology import (
|
||||
Connection,
|
||||
Cycle,
|
||||
@@ -188,24 +193,25 @@ class Topology:
|
||||
cycles.append(Cycle(node_ids=[node_id]))
|
||||
return cycles
|
||||
|
||||
def get_cycles_tb(self) -> list[Cycle]:
|
||||
tb_edges = [
|
||||
def get_rdma_cycles(self) -> list[Cycle]:
|
||||
rdma_edges = [
|
||||
(u, v, conn)
|
||||
for u, v, conn in self._graph.weighted_edge_list()
|
||||
if conn.is_thunderbolt()
|
||||
if isinstance(conn, RDMAConnection)
|
||||
]
|
||||
|
||||
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
|
||||
tb_graph.add_nodes_from(self._graph.nodes())
|
||||
rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
|
||||
rx.PyDiGraph()
|
||||
)
|
||||
rdma_graph.add_nodes_from(self._graph.nodes())
|
||||
|
||||
for u, v, conn in tb_edges:
|
||||
if isinstance(conn, SocketConnection):
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
for u, v, conn in rdma_edges:
|
||||
rdma_graph.add_edge(u, v, conn)
|
||||
|
||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||
cycle_idxs = rx.simple_cycles(rdma_graph)
|
||||
cycles: list[Cycle] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
|
||||
cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])
|
||||
cycles.append(cycle)
|
||||
|
||||
return cycles
|
||||
@@ -219,18 +225,83 @@ class Topology:
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
|
||||
def is_rdma_cycle(self, cycle: Cycle) -> bool:
|
||||
node_idxs = [node for node in cycle]
|
||||
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
|
||||
for rid in rx_idxs:
|
||||
for neighbor_rid in self._graph.neighbors(rid):
|
||||
if neighbor_rid not in rx_idxs:
|
||||
continue
|
||||
has_tb = False
|
||||
has_rdma = False
|
||||
for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
|
||||
if edge.is_thunderbolt():
|
||||
has_tb = True
|
||||
if isinstance(edge, RDMAConnection):
|
||||
has_rdma = True
|
||||
break
|
||||
if not has_tb:
|
||||
if not has_rdma:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_thunderbolt_bridge_cycles(
|
||||
self,
|
||||
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
) -> list[list[NodeId]]:
|
||||
"""
|
||||
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
|
||||
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
|
||||
2 or fewer nodes don't cause the broadcast storm problem.
|
||||
"""
|
||||
enabled_nodes = {
|
||||
node_id
|
||||
for node_id, status in node_tb_bridge_status.items()
|
||||
if status.enabled
|
||||
}
|
||||
|
||||
if len(enabled_nodes) < 3:
|
||||
return []
|
||||
|
||||
thunderbolt_ips = _get_ips_with_interface_type(
|
||||
enabled_nodes, node_network, "thunderbolt"
|
||||
)
|
||||
|
||||
# Build subgraph with only TB bridge enabled nodes and thunderbolt connections
|
||||
graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()
|
||||
node_to_idx: dict[NodeId, int] = {}
|
||||
|
||||
for node_id in enabled_nodes:
|
||||
if node_id in self._vertex_indices:
|
||||
node_to_idx[node_id] = graph.add_node(node_id)
|
||||
|
||||
for u, v, conn in self._graph.weighted_edge_list():
|
||||
source_id, sink_id = self._graph[u], self._graph[v]
|
||||
if source_id not in node_to_idx or sink_id not in node_to_idx:
|
||||
continue
|
||||
# Include connection if it's over a thunderbolt interface
|
||||
if (
|
||||
isinstance(conn, SocketConnection)
|
||||
and conn.sink_multiaddr.ip_address in thunderbolt_ips
|
||||
):
|
||||
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
|
||||
if isinstance(conn, RDMAConnection):
|
||||
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
|
||||
|
||||
return [
|
||||
[graph[idx] for idx in cycle]
|
||||
for cycle in rx.simple_cycles(graph)
|
||||
if len(cycle) > 2
|
||||
]
|
||||
|
||||
|
||||
def _get_ips_with_interface_type(
|
||||
node_ids: set[NodeId],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
interface_type: InterfaceType,
|
||||
) -> set[str]:
|
||||
"""Get all IP addresses on interfaces of the specified type for the given nodes."""
|
||||
ips: set[str] = set()
|
||||
for node_id in node_ids:
|
||||
network_info = node_network.get(node_id, NodeNetworkInfo())
|
||||
for iface in network_info.interfaces:
|
||||
if iface.interface_type == interface_type:
|
||||
ips.add(iface.ip_address)
|
||||
return ips
|
||||
|
||||
@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -54,6 +55,18 @@ class ChatCompletionMessageText(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str
|
||||
index: int | None = None
|
||||
type: Literal["function"] = "function"
|
||||
function: ToolCallItem
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
|
||||
content: (
|
||||
@@ -61,7 +74,7 @@ class ChatCompletionMessage(BaseModel):
|
||||
) = None
|
||||
thinking: str | None = None # Added for GPT-OSS harmony format support
|
||||
name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_call_id: str | None = None
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
@@ -340,3 +353,16 @@ class ImageListItem(BaseModel, frozen=True):
|
||||
|
||||
class ImageListResponse(BaseModel, frozen=True):
|
||||
data: list[ImageListItem]
|
||||
|
||||
|
||||
class StartDownloadParams(CamelCaseModel):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class StartDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
@@ -8,24 +7,29 @@ from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
from .common import CommandId
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
|
||||
Token = "Token"
|
||||
Image = "Image"
|
||||
from .worker.runner_response import ToolCallItem
|
||||
|
||||
|
||||
class BaseChunk(TaggedModel):
|
||||
idx: int
|
||||
model: ModelId
|
||||
|
||||
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ErrorChunk(BaseChunk):
|
||||
error_message: str
|
||||
finish_reason: Literal["error"] = "error"
|
||||
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
@@ -63,4 +67,4 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -62,6 +62,19 @@ class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
|
||||
class StartDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class DeleteDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
|
||||
|
||||
Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
@@ -79,3 +92,8 @@ Command = (
|
||||
class ForwarderCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: Command
|
||||
|
||||
|
||||
class ForwarderDownloadCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: DownloadCommand
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Self
|
||||
from typing import Literal, Self
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -48,9 +48,13 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
ecpu_usage: float = 0.0
|
||||
|
||||
|
||||
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
|
||||
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
name: str
|
||||
ip_address: str
|
||||
interface_type: InterfaceType = "unknown"
|
||||
|
||||
|
||||
class NodeIdentity(CamelCaseModel):
|
||||
@@ -71,3 +75,11 @@ class NodeThunderboltInfo(CamelCaseModel):
|
||||
"""Thunderbolt interface identifiers for a node."""
|
||||
|
||||
interfaces: Sequence[ThunderboltIdentifier] = []
|
||||
|
||||
|
||||
class ThunderboltBridgeStatus(CamelCaseModel):
|
||||
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
|
||||
|
||||
enabled: bool
|
||||
exists: bool
|
||||
service_name: str | None = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
SystemPerformanceProfile,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
|
||||
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
|
||||
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
|
||||
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
|
||||
|
||||
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
|
||||
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
|
||||
|
||||
@field_serializer("topology", mode="plain")
|
||||
def _encode_topology(self, value: Topology) -> TopologySnapshot:
|
||||
|
||||
@@ -21,9 +21,6 @@ class RDMAConnection(FrozenModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SocketConnection(FrozenModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
@@ -31,9 +28,6 @@ class SocketConnection(FrozenModel):
|
||||
def __hash__(self):
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
|
||||
class Connection(FrozenModel):
|
||||
source: NodeId
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import (
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -25,6 +30,7 @@ class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
stats: ImageGenerationStats | None = None
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -39,6 +45,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -48,5 +55,9 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
yield name, value
|
||||
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):
|
||||
|
||||
|
||||
class RunnerRunning(BaseRunnerStatus):
|
||||
"""Runner is processing requests and can accept more (continuous batching)."""
|
||||
|
||||
active_requests: int = 0
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
|
||||
@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.thunderbolt import (
|
||||
ThunderboltConnection,
|
||||
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
|
||||
IS_DARWIN = sys.platform == "darwin"
|
||||
|
||||
|
||||
async def _get_thunderbolt_devices() -> set[str] | None:
|
||||
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listallhardwareports"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listallhardwareports failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
output = result.stdout.decode()
|
||||
thunderbolt_devices: set[str] = set()
|
||||
current_port: str | None = None
|
||||
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("Hardware Port:"):
|
||||
current_port = line.split(":", 1)[1].strip()
|
||||
elif line.startswith("Device:") and current_port:
|
||||
device = line.split(":", 1)[1].strip()
|
||||
if "thunderbolt" in current_port.lower():
|
||||
thunderbolt_devices.add(device)
|
||||
current_port = None
|
||||
|
||||
return thunderbolt_devices
|
||||
|
||||
|
||||
async def _get_bridge_services() -> dict[str, str] | None:
|
||||
"""Get mapping of bridge device -> service name from network service order.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listnetworkserviceorder"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listnetworkserviceorder failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse service order to find bridge devices and their service names
|
||||
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
|
||||
service_order_output = result.stdout.decode()
|
||||
bridge_services: dict[str, str] = {} # device -> service name
|
||||
current_service: str | None = None
|
||||
|
||||
for line in service_order_output.splitlines():
|
||||
line = line.strip()
|
||||
# Match "(N) Service Name" or "(*) Service Name" (disabled)
|
||||
# but NOT "(Hardware Port: ...)" lines
|
||||
if (
|
||||
line
|
||||
and line.startswith("(")
|
||||
and ")" in line
|
||||
and not line.startswith("(Hardware Port:")
|
||||
):
|
||||
paren_end = line.index(")")
|
||||
if paren_end + 2 <= len(line):
|
||||
current_service = line[paren_end + 2 :]
|
||||
# Match "(Hardware Port: ..., Device: bridgeX)"
|
||||
elif current_service and "Device: bridge" in line:
|
||||
# Extract device name from "..., Device: bridge0)"
|
||||
device_start = line.find("Device: ") + len("Device: ")
|
||||
device_end = line.find(")", device_start)
|
||||
if device_end > device_start:
|
||||
device = line[device_start:device_end]
|
||||
bridge_services[device] = current_service
|
||||
|
||||
return bridge_services
|
||||
|
||||
|
||||
async def _get_bridge_members(bridge_device: str) -> set[str]:
|
||||
"""Get member interfaces of a bridge device via ifconfig."""
|
||||
result = await anyio.run_process(
|
||||
["ifconfig", bridge_device],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
|
||||
return set()
|
||||
|
||||
members: set[str] = set()
|
||||
ifconfig_output = result.stdout.decode()
|
||||
for line in ifconfig_output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("member:"):
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
members.add(parts[1])
|
||||
|
||||
return members
|
||||
|
||||
|
||||
async def _find_thunderbolt_bridge(
|
||||
bridge_services: dict[str, str], thunderbolt_devices: set[str]
|
||||
) -> str | None:
|
||||
"""Find the service name of a bridge containing Thunderbolt interfaces.
|
||||
|
||||
Returns the service name if found, None otherwise.
|
||||
"""
|
||||
for bridge_device, service_name in bridge_services.items():
|
||||
members = await _get_bridge_members(bridge_device)
|
||||
if members & thunderbolt_devices: # intersection is non-empty
|
||||
return service_name
|
||||
return None
|
||||
|
||||
|
||||
async def _is_service_enabled(service_name: str) -> bool | None:
|
||||
"""Check if a network service is enabled.
|
||||
|
||||
Returns True if enabled, False if disabled, None on error.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-getnetworkserviceenabled", service_name],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -getnetworkserviceenabled '{service_name}' "
|
||||
f"failed with code {result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
stdout = result.stdout.decode().strip().lower()
|
||||
return stdout == "enabled"
|
||||
|
||||
|
||||
class StaticNodeInformation(TaggedModel):
|
||||
"""Node information that should NEVER change, to be gathered once at startup"""
|
||||
|
||||
@@ -58,6 +195,66 @@ class MacThunderboltConnections(TaggedModel):
|
||||
conns: Sequence[ThunderboltConnection]
|
||||
|
||||
|
||||
class ThunderboltBridgeInfo(TaggedModel):
|
||||
status: ThunderboltBridgeStatus
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> Self | None:
|
||||
"""Check if a Thunderbolt Bridge network service is enabled on this node.
|
||||
|
||||
Detection approach:
|
||||
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
|
||||
2. Find bridge devices from network service order (not hardware ports, as
|
||||
bridges may not appear there)
|
||||
3. Check each bridge's members via ifconfig
|
||||
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
|
||||
5. Check if that network service is enabled
|
||||
"""
|
||||
if not IS_DARWIN:
|
||||
return None
|
||||
|
||||
def _no_bridge_status() -> Self:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=False, service_name=None
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
tb_devices = await _get_thunderbolt_devices()
|
||||
if tb_devices is None:
|
||||
return _no_bridge_status()
|
||||
|
||||
bridge_services = await _get_bridge_services()
|
||||
if not bridge_services:
|
||||
return _no_bridge_status()
|
||||
|
||||
tb_service_name = await _find_thunderbolt_bridge(
|
||||
bridge_services, tb_devices
|
||||
)
|
||||
if not tb_service_name:
|
||||
return _no_bridge_status()
|
||||
|
||||
enabled = await _is_service_enabled(tb_service_name)
|
||||
if enabled is None:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=True, service_name=tb_service_name
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=enabled,
|
||||
exists=True,
|
||||
service_name=tb_service_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class NodeConfig(TaggedModel):
|
||||
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||
|
||||
@@ -111,6 +308,7 @@ GatheredInfo = (
|
||||
| NodeNetworkInterfaces
|
||||
| MacThunderboltIdentifiers
|
||||
| MacThunderboltConnections
|
||||
| ThunderboltBridgeInfo
|
||||
| NodeConfig
|
||||
| MiscData
|
||||
| StaticNodeInformation
|
||||
@@ -125,6 +323,7 @@ class InfoGatherer:
|
||||
system_profiler_interval: float | None = 5 if IS_DARWIN else None
|
||||
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
async def run(self):
|
||||
@@ -133,6 +332,7 @@ class InfoGatherer:
|
||||
if (macmon_path := shutil.which("macmon")) is not None:
|
||||
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
|
||||
tg.start_soon(self._monitor_thunderbolt_bridge_status)
|
||||
tg.start_soon(self._watch_system_info)
|
||||
tg.start_soon(self._monitor_memory_usage)
|
||||
tg.start_soon(self._monitor_misc)
|
||||
@@ -200,12 +400,23 @@ class InfoGatherer:
|
||||
return
|
||||
old_nics = []
|
||||
while True:
|
||||
nics = get_network_interfaces()
|
||||
nics = await get_network_interfaces()
|
||||
if nics != old_nics:
|
||||
old_nics = nics
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
async def _monitor_macmon(self, macmon_path: str):
|
||||
if self.macmon_interval is None:
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ from subprocess import CalledProcessError
|
||||
import psutil
|
||||
from anyio import run_process
|
||||
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
|
||||
|
||||
|
||||
async def get_friendly_name() -> str:
|
||||
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
|
||||
"""
|
||||
hostname = socket.gethostname()
|
||||
|
||||
# TODO: better non mac support
|
||||
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
|
||||
if sys.platform != "darwin":
|
||||
return hostname
|
||||
|
||||
try:
|
||||
@@ -28,7 +27,41 @@ async def get_friendly_name() -> str:
|
||||
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
|
||||
|
||||
|
||||
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
"""Parse networksetup -listallhardwareports to get interface types."""
|
||||
if sys.platform != "darwin":
|
||||
return {}
|
||||
|
||||
try:
|
||||
result = await run_process(["networksetup", "-listallhardwareports"])
|
||||
except CalledProcessError:
|
||||
return {}
|
||||
|
||||
types: dict[str, InterfaceType] = {}
|
||||
current_type: InterfaceType = "unknown"
|
||||
|
||||
for line in result.stdout.decode().splitlines():
|
||||
if line.startswith("Hardware Port:"):
|
||||
port_name = line.split(":", 1)[1].strip()
|
||||
if "Wi-Fi" in port_name:
|
||||
current_type = "wifi"
|
||||
elif "Ethernet" in port_name or "LAN" in port_name:
|
||||
current_type = "ethernet"
|
||||
elif port_name.startswith("Thunderbolt"):
|
||||
current_type = "thunderbolt"
|
||||
else:
|
||||
current_type = "unknown"
|
||||
elif line.startswith("Device:"):
|
||||
device = line.split(":", 1)[1].strip()
|
||||
# enX is ethernet adapters or thunderbolt - these must be deprioritised
|
||||
if device.startswith("en") and device not in ["en0", "en1"]:
|
||||
current_type = "maybe_ethernet"
|
||||
types[device] = current_type
|
||||
|
||||
return types
|
||||
|
||||
|
||||
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
"""
|
||||
Retrieves detailed network interface information on macOS.
|
||||
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
|
||||
@@ -36,13 +69,18 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
Returns a list of NetworkInterfaceInfo objects.
|
||||
"""
|
||||
interfaces_info: list[NetworkInterfaceInfo] = []
|
||||
interface_types = await _get_interface_types_from_networksetup()
|
||||
|
||||
for iface, services in psutil.net_if_addrs().items():
|
||||
for service in services:
|
||||
match service.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
interfaces_info.append(
|
||||
NetworkInterfaceInfo(name=iface, ip_address=service.address)
|
||||
NetworkInterfaceInfo(
|
||||
name=iface,
|
||||
ip_address=service.address,
|
||||
interface_type=interface_types.get(iface, "unknown"),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
32
src/exo/utils/keyed_backoff.py
Normal file
32
src/exo/utils/keyed_backoff.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
|
||||
|
||||
class KeyedBackoff(Generic[K]):
|
||||
"""Tracks exponential backoff state per key."""
|
||||
|
||||
def __init__(self, base: float = 0.5, cap: float = 10.0):
|
||||
self._base = base
|
||||
self._cap = cap
|
||||
self._attempts: dict[K, int] = {}
|
||||
self._last_time: dict[K, float] = {}
|
||||
|
||||
def should_proceed(self, key: K) -> bool:
|
||||
"""Returns True if enough time has elapsed since last attempt."""
|
||||
now = time.monotonic()
|
||||
last = self._last_time.get(key, 0.0)
|
||||
attempts = self._attempts.get(key, 0)
|
||||
delay = min(self._cap, self._base * (2.0**attempts))
|
||||
return now - last >= delay
|
||||
|
||||
def record_attempt(self, key: K) -> None:
|
||||
"""Record that an attempt was made for this key."""
|
||||
self._last_time[key] = time.monotonic()
|
||||
self._attempts[key] = self._attempts.get(key, 0) + 1
|
||||
|
||||
def reset(self, key: K) -> None:
|
||||
"""Reset backoff state for a key (e.g., on success)."""
|
||||
self._attempts.pop(key, None)
|
||||
self._last_time.pop(key, None)
|
||||
@@ -6,10 +6,10 @@ import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
|
||||
@@ -75,19 +75,20 @@ def generate_image(
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
advanced_params = task.advanced_params
|
||||
if advanced_params is not None and advanced_params.seed is not None:
|
||||
seed = advanced_params.seed
|
||||
base_seed = advanced_params.seed
|
||||
else:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
base_seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
is_bench = getattr(task, "bench", False)
|
||||
num_images = task.n or 1
|
||||
|
||||
generation_start_time: float = 0.0
|
||||
|
||||
@@ -95,7 +96,11 @@ def generate_image(
|
||||
mx.reset_peak_memory()
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None
|
||||
else (3 if task.stream else 0)
|
||||
)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -105,72 +110,81 @@ def generate_image(
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
current_seed = base_seed + image_num
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=current_seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = generation_end_time - generation_start_time
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / num_inference_steps
|
||||
if num_inference_steps > 0
|
||||
else 0.0
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
image_index=image_num,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
# Only include stats on the final image
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench and image_num == num_images - 1:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = (
|
||||
generation_end_time - generation_start_time
|
||||
)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=task.n or 1,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
total_steps = num_inference_steps * num_images
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / total_steps
|
||||
if total_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=num_images,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
image_index=image_num,
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@@ -1,302 +0,0 @@
|
||||
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.distributed_sync import share_object
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""Tracks an active request in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
uid: int # BatchGenerator's internal ID
|
||||
detokenizer: StreamingDetokenizer
|
||||
tokens_generated: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedGenerationResponse:
|
||||
"""Response from batch engine, tagged with command_id and task_id."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
response: GenerationResponse
|
||||
|
||||
|
||||
class BatchGenerationEngine:
|
||||
"""Manages continuous batching using mlx_lm's BatchGenerator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
max_tokens: int = MAX_TOKENS,
|
||||
completion_batch_size: int = 32,
|
||||
prefill_batch_size: int = 8,
|
||||
prefill_step_size: int = 2048,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.max_tokens = max_tokens
|
||||
self.active_requests: dict[int, ActiveRequest] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, ChatCompletionTaskParams]
|
||||
] = []
|
||||
self._pending_completions: list[
|
||||
int
|
||||
] = [] # UIDs completed but not yet synced/removed
|
||||
|
||||
self.group = group
|
||||
self.rank = group.rank() if group else 0
|
||||
self.is_distributed = group is not None and group.size() > 1
|
||||
|
||||
sampler = make_sampler(temp=0.7, top_p=1.0)
|
||||
|
||||
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
|
||||
|
||||
self.batch_gen: BatchGenerator = BatchGenerator(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
stop_tokens=eos_tokens,
|
||||
sampler=sampler,
|
||||
completion_batch_size=completion_batch_size,
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
prefill_step_size=prefill_step_size,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
|
||||
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
|
||||
)
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams,
|
||||
) -> None:
|
||||
"""Queue a request for insertion. Only rank 0 should call this.
|
||||
|
||||
In distributed mode, rank 0 receives tasks from the control plane and
|
||||
queues them here. The actual insertion happens in sync_and_insert_pending()
|
||||
which ensures all ranks insert the same requests together.
|
||||
"""
|
||||
assert self.rank == 0, "Only rank 0 should queue requests"
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
logger.info(
|
||||
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
|
||||
)
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Sync pending inserts across ranks and insert them. Returns UIDs.
|
||||
|
||||
This method ensures all ranks insert the same requests in the same order.
|
||||
In non-distributed mode, it simply inserts all pending requests.
|
||||
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
|
||||
|
||||
Batches all pending inserts into a single batch_gen.insert() call for
|
||||
efficient prefill batching.
|
||||
"""
|
||||
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
|
||||
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: just insert directly from pending
|
||||
inserts_to_process = list(self._pending_inserts)
|
||||
else:
|
||||
# Distributed: broadcast pending inserts from rank 0 to all ranks
|
||||
assert self.group is not None
|
||||
pending_data = self._pending_inserts if self.rank == 0 else None
|
||||
synced_data = share_object(pending_data, self.rank, self.group)
|
||||
|
||||
if synced_data is None:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
inserts_to_process = synced_data
|
||||
|
||||
if not inserts_to_process:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
# Prepare all requests for batched insertion
|
||||
all_tokens: list[list[int]] = []
|
||||
all_max_tokens: list[int] = []
|
||||
all_prompt_tokens: list[int] = []
|
||||
request_info: list[tuple[CommandId, TaskId]] = []
|
||||
|
||||
for cmd_id, task_id, params in inserts_to_process:
|
||||
prompt_str = apply_chat_template(self.tokenizer, params)
|
||||
tokens: list[int] = self.tokenizer.encode(
|
||||
prompt_str, add_special_tokens=False
|
||||
)
|
||||
max_tokens = params.max_tokens or self.max_tokens
|
||||
|
||||
all_tokens.append(tokens)
|
||||
all_max_tokens.append(max_tokens)
|
||||
all_prompt_tokens.append(len(tokens))
|
||||
request_info.append((cmd_id, task_id))
|
||||
|
||||
# Single batched insert for efficient prefill
|
||||
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
|
||||
|
||||
# Track all inserted requests
|
||||
for i, uid in enumerate(uids):
|
||||
cmd_id, task_id = request_info[i]
|
||||
self.active_requests[uid] = ActiveRequest(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
uid=uid,
|
||||
detokenizer=self.tokenizer.detokenizer,
|
||||
prompt_tokens=all_prompt_tokens[i],
|
||||
)
|
||||
logger.info(
|
||||
f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
|
||||
)
|
||||
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
|
||||
responses = self.batch_gen.next()
|
||||
if not responses:
|
||||
return []
|
||||
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
|
||||
for r in responses:
|
||||
uid: int = r.uid
|
||||
req = self.active_requests.get(uid)
|
||||
if req is None:
|
||||
logger.warning(f"Received response for unknown uid={uid}")
|
||||
continue
|
||||
|
||||
req.tokens_generated += 1
|
||||
|
||||
# Decode the token
|
||||
token: int = r.token
|
||||
req.detokenizer.add_token(token)
|
||||
text: str = req.detokenizer.last_segment
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
raw_finish_reason: str | None = r.finish_reason
|
||||
if raw_finish_reason is not None:
|
||||
# Finalize to get remaining text
|
||||
req.detokenizer.finalize()
|
||||
text = req.detokenizer.last_segment
|
||||
|
||||
elapsed = time.perf_counter() - req.start_time
|
||||
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=0.0, # Not tracked per-request in batch mode
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=req.prompt_tokens,
|
||||
generation_tokens=req.tokens_generated,
|
||||
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
|
||||
)
|
||||
|
||||
if raw_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif raw_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
|
||||
finish_reason = "stop"
|
||||
|
||||
# Track completion but don't remove yet - wait for sync_completions()
|
||||
self._pending_completions.append(uid)
|
||||
logger.info(
|
||||
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=req.command_id,
|
||||
task_id=req.task_id,
|
||||
response=GenerationResponse(
|
||||
text=text, token=token, finish_reason=finish_reason, stats=stats
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# In non-distributed mode, clean up completions immediately
|
||||
if not self.is_distributed:
|
||||
self._remove_completed()
|
||||
|
||||
return results
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: early return if nothing to do
|
||||
if not self._pending_completions:
|
||||
return
|
||||
self._remove_completed()
|
||||
return
|
||||
|
||||
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
|
||||
# This prevents deadlock if one rank has completions and another doesn't
|
||||
assert self.group is not None
|
||||
synced_uids = share_object(
|
||||
self._pending_completions if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
if synced_uids:
|
||||
self._pending_completions = synced_uids
|
||||
|
||||
self._remove_completed()
|
||||
|
||||
def _remove_completed(self) -> None:
|
||||
"""Remove completed requests from tracking."""
|
||||
for uid in self._pending_completions:
|
||||
if uid in self.active_requests:
|
||||
del self.active_requests[uid]
|
||||
self._pending_completions.clear()
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self.active_requests)
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return len(self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def has_pending_completions(self) -> bool:
|
||||
return bool(self._pending_completions)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
|
||||
|
||||
# pyright: reportAny=false
|
||||
|
||||
import pickle
|
||||
from typing import TypeVar, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
|
||||
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
|
||||
if rank == 0:
|
||||
if obj is None:
|
||||
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
|
||||
return None
|
||||
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
|
||||
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
|
||||
mx.eval(mx.distributed.all_sum(data, group=group))
|
||||
return obj
|
||||
else:
|
||||
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
|
||||
if size == 0:
|
||||
return None
|
||||
data = mx.zeros(size, dtype=mx.uint8)
|
||||
data = mx.distributed.all_sum(data, group=group)
|
||||
mx.eval(data)
|
||||
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Time budget iterator for controlling generation loop timing in distributed mode.
|
||||
|
||||
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
|
||||
rather than syncing every token. This reduces distributed sync overhead.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class TimeBudget(Iterator[None]):
|
||||
"""Controls generation loop timing, syncing across ranks periodically.
|
||||
|
||||
In distributed mode, periodically syncs timing across all ranks to
|
||||
dynamically adjust iteration count based on actual performance.
|
||||
|
||||
In non-distributed mode, simply runs for the time budget.
|
||||
|
||||
Usage:
|
||||
for _ in TimeBudget(budget=0.5):
|
||||
batch_engine.step()
|
||||
# ... process responses ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
budget: float = 0.5,
|
||||
iterations: int = 25,
|
||||
sync_frequency: int = 10,
|
||||
group: mx.distributed.Group | None = None,
|
||||
):
|
||||
"""Initialize TimeBudget.
|
||||
|
||||
Args:
|
||||
budget: Time budget in seconds before yielding control
|
||||
iterations: Initial number of iterations per budget period (distributed only)
|
||||
sync_frequency: How often to sync timing across ranks (distributed only)
|
||||
group: Distributed group, or None for non-distributed mode
|
||||
"""
|
||||
self._budget = budget
|
||||
self._iterations = iterations
|
||||
self._sync_frequency = sync_frequency
|
||||
self._group = group
|
||||
self._is_distributed = group is not None and group.size() > 1
|
||||
|
||||
# Runtime state
|
||||
self._start: float = 0.0
|
||||
self._current_iterations: int = 0
|
||||
self._loops: int = 0
|
||||
self._time_spent: float = 0.0
|
||||
|
||||
def __iter__(self) -> "TimeBudget":
|
||||
self._start = time.perf_counter()
|
||||
self._current_iterations = 0
|
||||
return self
|
||||
|
||||
def __next__(self) -> None:
|
||||
if not self._is_distributed:
|
||||
# Non-distributed: just check time budget
|
||||
if time.perf_counter() - self._start > self._budget:
|
||||
raise StopIteration()
|
||||
return None
|
||||
|
||||
# Distributed mode: iteration-based with periodic timing sync
|
||||
self._current_iterations += 1
|
||||
if self._current_iterations > self._iterations:
|
||||
self._loops += 1
|
||||
self._time_spent += time.perf_counter() - self._start
|
||||
|
||||
if self._loops % self._sync_frequency == 0:
|
||||
# Sync timing across all ranks
|
||||
assert self._group is not None
|
||||
with mx.stream(generation_stream):
|
||||
time_array = mx.array([self._time_spent], dtype=mx.float32)
|
||||
total_time = mx.distributed.all_sum(time_array, group=self._group)
|
||||
mx.eval(total_time)
|
||||
loop_time = float(total_time.item())
|
||||
|
||||
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
|
||||
|
||||
if avg_loop_time > 0:
|
||||
factor = self._budget / avg_loop_time
|
||||
self._iterations = max(round(self._iterations * factor), 1)
|
||||
logger.debug(
|
||||
f"TimeBudget adjusted iterations to {self._iterations}"
|
||||
)
|
||||
|
||||
self._loops = 0
|
||||
self._time_spent = 0.0
|
||||
|
||||
raise StopIteration()
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Current iterations per budget period."""
|
||||
return self._iterations
|
||||
@@ -41,6 +41,7 @@ import mlx.nn as nn
|
||||
from mlx_lm.utils import load_model
|
||||
from pydantic import RootModel
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -55,7 +56,6 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
TimeoutCallback,
|
||||
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize tool_calls in a message dict.
|
||||
|
||||
OpenAI format has tool_calls[].function.arguments as a JSON string,
|
||||
but some chat templates (e.g., GLM) expect it as a dict.
|
||||
"""
|
||||
tool_calls = msg_dict.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
return
|
||||
|
||||
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if not isinstance(func, dict):
|
||||
continue
|
||||
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if isinstance(args, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
func["arguments"] = json.loads(args)
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
tools = chat_task_data.tools
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
@@ -386,15 +409,19 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
dumped: dict[str, Any] = message.model_dump()
|
||||
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
|
||||
|
||||
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
|
||||
_normalize_tool_calls(msg_dict)
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
@@ -10,7 +11,12 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.commands import (
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
@@ -18,7 +24,6 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -36,22 +41,12 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -61,7 +56,6 @@ class Worker:
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
shard_downloader: ShardDownloader,
|
||||
*,
|
||||
connection_message_receiver: Receiver[ConnectionMessage],
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
@@ -69,23 +63,22 @@ class Worker:
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.shard_downloader: ShardDownloader = shard_downloader
|
||||
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.local_event_index = 0
|
||||
self.event_index_counter = event_index_counter
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.connection_message_receiver = connection_message_receiver
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
|
||||
@@ -100,6 +93,8 @@ class Worker:
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -110,7 +105,6 @@ class Worker:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -120,6 +114,7 @@ class Worker:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
@@ -178,11 +173,9 @@ class Worker:
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
# 3. based on the updated state, we plan & execute an operation.
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.download_status,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
@@ -206,42 +199,26 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard.model_card.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(
|
||||
shard
|
||||
model_id = shard.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
)
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
self._handle_shard_download_process(task, initial_progress)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
@@ -386,78 +363,17 @@ class Worker:
|
||||
self._tg.start_soon(runner.run)
|
||||
return runner
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
initial_progress: RepoDownloadProgress,
|
||||
):
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=self.local_event_index,
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
|
||||
)
|
||||
self.local_event_index += 1
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
@@ -505,42 +421,3 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.debug("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
@@ -20,6 +19,7 @@ from exo.shared.types.tasks import (
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
@@ -44,9 +44,6 @@ def plan(
|
||||
node_id: NodeId,
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
@@ -58,7 +55,7 @@ def plan(
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
@@ -114,15 +111,22 @@ def _create_runner(
|
||||
|
||||
|
||||
def _model_needs_download(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> DownloadModel | None:
|
||||
local_downloads = global_download_status.get(node_id, [])
|
||||
download_status = {
|
||||
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
|
||||
}
|
||||
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_card.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
download_status[model_id],
|
||||
(DownloadOngoing, DownloadCompleted, DownloadFailed),
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
@@ -291,14 +295,12 @@ def _pending_tasks(
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
|
||||
if task.task_id in runner.completed or task.task_id in runner.pending:
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -105,7 +105,7 @@ class RunnerSupervisor:
|
||||
return
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown successfully, terminating")
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
@@ -128,11 +128,9 @@ class RunnerSupervisor:
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.completed:
|
||||
logger.info(f"Skipping task {task.task_id} - already completed")
|
||||
return
|
||||
if task.task_id in self.pending:
|
||||
logger.info(f"Skipping task {task.task_id} - already pending")
|
||||
return
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -151,17 +149,13 @@ class RunnerSupervisor:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
# Just set the event to unblock start_task, but keep in pending
|
||||
# to prevent duplicate forwarding until completion
|
||||
if event.task_id in self.pending:
|
||||
self.pending[event.task_id].set()
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if isinstance(event, TaskStatusUpdated) and event.task_status in (
|
||||
TaskStatus.Complete,
|
||||
TaskStatus.TimedOut,
|
||||
TaskStatus.Failed,
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
and event.task_status == TaskStatus.Complete
|
||||
):
|
||||
# If a task has just finished, we should be working on it.
|
||||
# If a task has just been completed, we should be working on it.
|
||||
assert isinstance(
|
||||
self.status,
|
||||
(
|
||||
@@ -172,8 +166,6 @@ class RunnerSupervisor:
|
||||
RunnerShuttingDown,
|
||||
),
|
||||
)
|
||||
# Now safe to remove from pending and add to completed
|
||||
self.pending.pop(event.task_id, None)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
|
||||
@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
pending: dict[TaskId, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -11,12 +11,12 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import ModelId, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
@@ -45,13 +45,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=download_status,
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -92,14 +88,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -116,7 +104,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -148,23 +135,19 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
# Global state shows shard is downloaded for NODE_A
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -202,12 +185,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -220,7 +197,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -245,7 +221,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -47,8 +47,7 @@ def test_plan_kills_runner_when_instance_missing():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -87,8 +86,7 @@ def test_plan_kills_runner_when_sibling_failed():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -120,7 +118,6 @@ def test_plan_creates_runner_when_missing_for_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners,
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -158,8 +155,7 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -189,7 +185,6 @@ def test_plan_does_not_create_runner_for_unassigned_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -65,7 +65,6 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -113,7 +112,6 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -158,7 +156,6 @@ def test_plan_does_not_forward_tasks_for_other_instances():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -221,7 +218,6 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -261,7 +257,6 @@ def test_plan_returns_none_when_nothing_to_do():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -57,7 +57,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -99,7 +98,6 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -140,7 +138,6 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -185,7 +182,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -202,7 +198,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -246,7 +241,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -289,7 +283,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -331,7 +324,6 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""
|
||||
Tests for continuous batching behavior in the runner.
|
||||
|
||||
These tests verify that:
|
||||
1. Single requests work through the batch path
|
||||
2. Multiple concurrent requests batch together
|
||||
3. Tokens are routed to the correct requests
|
||||
4. Requests complete at different times appropriately
|
||||
"""
|
||||
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ChatCompletionTaskParams,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerRunning
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
|
||||
class FakeBatchEngineWithTokens:
|
||||
"""
|
||||
Fake batch engine that generates a specified number of tokens per request.
|
||||
|
||||
This simulates realistic batch generation behavior where:
|
||||
- Requests are queued on insert
|
||||
- Each step() call generates one token for all active requests
|
||||
- Requests complete when they've generated all their tokens
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, ChatCompletionTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
|
||||
self.rank = 0 # Fake rank for testing
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams,
|
||||
) -> None:
|
||||
"""Queue a request for insertion."""
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Insert all pending requests."""
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, task_params in self._pending_inserts:
|
||||
uid = self._do_insert(command_id, task_id, task_params)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return len(self._pending_inserts) > 0
|
||||
|
||||
def _do_insert(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams | None,
|
||||
) -> int:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
# Track: (command_id, task_id, tokens_generated, max_tokens)
|
||||
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
|
||||
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
|
||||
return uid
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
uids_to_remove: list[int] = []
|
||||
|
||||
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
|
||||
self._active_requests.items()
|
||||
):
|
||||
tokens_gen += 1
|
||||
finish_reason = "stop" if tokens_gen >= max_tokens else None
|
||||
text = f"token{tokens_gen}"
|
||||
|
||||
if finish_reason:
|
||||
uids_to_remove.append(uid)
|
||||
else:
|
||||
self._active_requests[uid] = (
|
||||
command_id,
|
||||
task_id,
|
||||
tokens_gen,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=tokens_gen,
|
||||
text=text,
|
||||
finish_reason=finish_reason,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for uid in uids_to_remove:
|
||||
del self._active_requests[uid]
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return len(self._active_requests) > 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False # Non-distributed mode for testing
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
"""Fake MLX distributed group for testing."""
|
||||
|
||||
def size(self) -> int:
|
||||
return 1 # Single node (non-distributed)
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T):
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
|
||||
)
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
|
||||
|
||||
|
||||
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
|
||||
"""
|
||||
Run tasks through the runner, adding shutdown at the end.
|
||||
|
||||
Tasks are sent in order, with shutdown sent last.
|
||||
The batch engine processes between task handling.
|
||||
"""
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NodeId(NODE_A),
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_sender, event_receiver = mp_channel[Event]()
|
||||
|
||||
shutdown_task = Shutdown(
|
||||
task_id=TaskId("shutdown"),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
with task_sender, event_receiver:
|
||||
# Send all tasks including shutdown
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
task_sender.send(shutdown_task)
|
||||
|
||||
# Disable cleanup methods to prevent issues
|
||||
event_sender.close = lambda: None
|
||||
event_sender.join = lambda: None
|
||||
task_receiver.close = lambda: None
|
||||
task_receiver.join = lambda: None
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver)
|
||||
|
||||
return event_receiver.collect()
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
|
||||
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
|
||||
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
|
||||
|
||||
|
||||
def make_chat_task(
|
||||
task_id: str, command_id: str, max_tokens: int = 3
|
||||
) -> ChatCompletion:
|
||||
return ChatCompletion(
|
||||
task_id=TaskId(task_id),
|
||||
command_id=CommandId(command_id),
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[ChatCompletionMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def test_single_request_generates_tokens(patch_batch_engine: None):
|
||||
"""
|
||||
Verify a single request generates the expected tokens through the batch path.
|
||||
|
||||
Note: With the current non-blocking design, shutdown is processed before
|
||||
batch steps run when all tasks are queued together. This test verifies
|
||||
the runner status reflects active requests.
|
||||
"""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events - this shows the request was inserted
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
|
||||
"""Verify RunnerRunning status includes active_requests count."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_chat_task_acknowledged(patch_batch_engine: None):
|
||||
"""Verify chat completion task is acknowledged with proper status updates."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find the chat task status events
|
||||
chat_running = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Running
|
||||
]
|
||||
|
||||
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
|
||||
|
||||
|
||||
def test_multiple_requests_tracked(patch_batch_engine: None):
|
||||
"""Verify multiple concurrent requests are tracked in active_requests."""
|
||||
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
# Should have at least 2 RunnerRunning events (one per request inserted)
|
||||
assert len(running_events) >= 2, (
|
||||
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
|
||||
)
|
||||
|
||||
# First should have 1 active request, second should have 2
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
assert running_events[1].runner_status.active_requests == 2
|
||||
@@ -1,17 +1,12 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
|
||||
# pyright: reportAny=false
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable
|
||||
from unittest.mock import MagicMock
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -27,7 +22,6 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
@@ -44,9 +38,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
|
||||
from ...constants import (
|
||||
CHAT_COMPLETION_TASK_ID,
|
||||
@@ -116,100 +107,22 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
assert test_event == true_event, f"{test_event} != {true_event}"
|
||||
|
||||
|
||||
class FakeBatchEngine:
|
||||
"""
|
||||
Fake batch engine for testing.
|
||||
|
||||
Queues requests on insert, returns one token per step.
|
||||
The runner's non-blocking loop drains all tasks before running batch steps,
|
||||
so this engine queues requests and has_active_requests returns True only
|
||||
after at least one request has been inserted.
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, ChatCompletionTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self.rank = 0 # Fake rank for testing
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams,
|
||||
) -> None:
|
||||
"""Queue a request for insertion."""
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Insert all pending requests."""
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, _task_params in self._pending_inserts:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
self._active_requests[uid] = (command_id, task_id)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return len(self._pending_inserts) > 0
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
# Process all active requests - return one token and complete
|
||||
for uid, (command_id, task_id) in list(self._active_requests.items()):
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=0,
|
||||
text="hi",
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
)
|
||||
del self._active_requests[uid]
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return len(self._active_requests) > 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False # Non-distributed mode for testing
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
"""Fake MLX distributed group for testing."""
|
||||
|
||||
def size(self) -> int:
|
||||
return 1 # Single node (non-distributed)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a fake "group" (non-None for state machine)
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
|
||||
)
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
|
||||
# Use a fake event_sender to remove test flakiness.
|
||||
@@ -227,6 +140,13 @@ class EventCollector:
|
||||
pass
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
tool_parser = None
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -252,14 +172,12 @@ def _run(tasks: Iterable[Task]):
|
||||
return event_sender.events
|
||||
|
||||
|
||||
def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
|
||||
"""Verify chat completion generates tokens, completes, and runner returns to Ready."""
|
||||
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
|
||||
|
||||
expected_chunk = ChunkGenerated(
|
||||
command_id=COMMAND_1_ID,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="hi",
|
||||
token_id=0,
|
||||
@@ -296,9 +214,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
expected_chunk,
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
|
||||
@@ -313,6 +229,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user