mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-10 06:59:49 -05:00
Compare commits
11 Commits
alexcheema
...
v1.0.61-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f6fcd9e93 | ||
|
|
839b67f318 | ||
|
|
47b8e0ce12 | ||
|
|
17f9b583a4 | ||
|
|
844bcc7ce6 | ||
|
|
c1be5184b2 | ||
|
|
1ec550dff1 | ||
|
|
283c0e39e4 | ||
|
|
35be4c55c3 | ||
|
|
31d4cd8409 | ||
|
|
8a6da58404 |
@@ -20,6 +20,8 @@ struct ContentView: View {
|
||||
@State private var showDebugInfo = false
|
||||
@State private var bugReportInFlight = false
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var showAdvancedOptions = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -197,6 +199,8 @@ struct ContentView: View {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
.padding(.bottom, 8)
|
||||
advancedOptionsSection
|
||||
.padding(.bottom, 8)
|
||||
debugSection
|
||||
.padding(.bottom, 8)
|
||||
controlButton(title: "Quit", tint: .secondary) {
|
||||
@@ -327,6 +331,47 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var advancedOptionsSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Advanced Options")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showAdvancedOptions)
|
||||
}
|
||||
.animation(nil, value: showAdvancedOptions)
|
||||
if showAdvancedOptions {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Cluster Namespace")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
TextField("optional", text: $pendingNamespace)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingNamespace = controller.customNamespace
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.customNamespace = pendingNamespace
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
|
||||
}
|
||||
|
||||
private var debugSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
|
||||
@@ -2,6 +2,8 @@ import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
|
||||
@Published private(set) var status: Status = .stopped
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var launchCountdownSeconds: Int?
|
||||
@Published var customNamespace: String = {
|
||||
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
|
||||
}() {
|
||||
didSet {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
|
||||
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
|
||||
}
|
||||
return "dev"
|
||||
}
|
||||
|
||||
private func computeNamespace() -> String {
|
||||
let base = buildTag()
|
||||
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
|
||||
return custom.isEmpty ? base : custom
|
||||
}
|
||||
}
|
||||
|
||||
struct RuntimeError: LocalizedError {
|
||||
|
||||
@@ -139,6 +139,11 @@
|
||||
}
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
|
||||
if (event.isComposing || event.keyCode === 229) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'Enter' && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSubmit();
|
||||
|
||||
@@ -51,6 +51,59 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
|
||||
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
|
||||
|
||||
// Launch defaults persistence
|
||||
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
|
||||
interface LaunchDefaults {
|
||||
modelId: string | null;
|
||||
sharding: 'Pipeline' | 'Tensor';
|
||||
instanceType: InstanceMeta;
|
||||
minNodes: number;
|
||||
}
|
||||
|
||||
function saveLaunchDefaults(): void {
|
||||
const defaults: LaunchDefaults = {
|
||||
modelId: selectedPreviewModelId(),
|
||||
sharding: selectedSharding,
|
||||
instanceType: selectedInstanceType,
|
||||
minNodes: selectedMinNodes,
|
||||
};
|
||||
try {
|
||||
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
|
||||
} catch (e) {
|
||||
console.warn('Failed to save launch defaults:', e);
|
||||
}
|
||||
}
|
||||
|
||||
function loadLaunchDefaults(): LaunchDefaults | null {
|
||||
try {
|
||||
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
|
||||
if (!stored) return null;
|
||||
return JSON.parse(stored) as LaunchDefaults;
|
||||
} catch (e) {
|
||||
console.warn('Failed to load launch defaults:', e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
|
||||
const defaults = loadLaunchDefaults();
|
||||
if (!defaults) return;
|
||||
|
||||
// Apply sharding and instance type unconditionally
|
||||
selectedSharding = defaults.sharding;
|
||||
selectedInstanceType = defaults.instanceType;
|
||||
|
||||
// Apply minNodes if valid (between 1 and maxNodes)
|
||||
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
|
||||
selectedMinNodes = defaults.minNodes;
|
||||
}
|
||||
|
||||
// Only apply model if it exists in the available models
|
||||
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
}
|
||||
}
|
||||
|
||||
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
|
||||
let selectedMinNodes = $state<number>(1);
|
||||
let minNodesInitialized = $state(false);
|
||||
@@ -298,6 +351,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const data = await response.json();
|
||||
// API returns { data: [{ id, name }] } format
|
||||
models = data.data || [];
|
||||
// Restore last launch defaults if available
|
||||
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
|
||||
applyLaunchDefaults(models, currentNodeCount);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch models:', error);
|
||||
@@ -988,6 +1044,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
function handleSliderMouseUp() {
|
||||
isDraggingSlider = false;
|
||||
saveLaunchDefaults();
|
||||
}
|
||||
|
||||
// Handle touch events for mobile
|
||||
@@ -1007,6 +1064,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
function handleSliderTouchEnd() {
|
||||
isDraggingSlider = false;
|
||||
saveLaunchDefaults();
|
||||
}
|
||||
|
||||
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
|
||||
@@ -1464,6 +1522,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
onclick={() => {
|
||||
if (modelCanFit) {
|
||||
selectPreviewModel(model.id);
|
||||
saveLaunchDefaults();
|
||||
isModelDropdownOpen = false;
|
||||
modelDropdownSearch = '';
|
||||
}
|
||||
@@ -1497,7 +1556,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
onclick={() => selectedSharding = 'Pipeline'}
|
||||
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1508,7 +1567,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
Pipeline
|
||||
</button>
|
||||
<button
|
||||
onclick={() => selectedSharding = 'Tensor'}
|
||||
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1526,7 +1585,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
onclick={() => selectedInstanceType = 'MlxRing'}
|
||||
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1537,7 +1596,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
MLX Ring
|
||||
</button>
|
||||
<button
|
||||
onclick={() => selectedInstanceType = 'MlxIbv'}
|
||||
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
|
||||
@@ -28,7 +28,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
worker: Worker
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
master: Master | None
|
||||
@@ -62,15 +62,19 @@ class Node:
|
||||
else:
|
||||
api = None
|
||||
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
@@ -100,8 +104,9 @@ class Node:
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.worker.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
tg.start_soon(self.master.run)
|
||||
if self.api:
|
||||
@@ -209,6 +214,7 @@ class Args(CamelCaseModel):
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -246,6 +252,10 @@ class Args(CamelCaseModel):
|
||||
dest="api_port",
|
||||
default=52415,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
|
||||
)
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -29,6 +30,7 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
@@ -65,6 +67,28 @@ def place_instance(
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
|
||||
if command.sharding == Sharding.Tensor:
|
||||
if not command.model_meta.supports_tensor:
|
||||
raise ValueError(
|
||||
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
|
||||
)
|
||||
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
|
||||
cycles_with_sufficient_memory = [
|
||||
cycle
|
||||
for cycle in cycles_with_sufficient_memory
|
||||
if command.model_meta.hidden_size % len(cycle) == 0
|
||||
]
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError(
|
||||
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
|
||||
)
|
||||
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
|
||||
"mlx-community/DeepSeek-V3.1-8bit"
|
||||
):
|
||||
raise ValueError(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_tb_cycles = [
|
||||
|
||||
@@ -385,13 +385,14 @@ def get_mlx_jaccl_coordinators(
|
||||
address in format "X.X.X.X:PORT" per node.
|
||||
"""
|
||||
rank_0_node = selected_cycle[0]
|
||||
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
||||
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
||||
|
||||
def get_ip_for_node(n: NodeInfo) -> str:
|
||||
if n.node_id == rank_0_node.node_id:
|
||||
return "0.0.0.0"
|
||||
|
||||
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
||||
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
|
||||
if ip:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
|
||||
@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShutdown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
@@ -70,6 +74,7 @@ RunnerStatus = (
|
||||
| RunnerWarmingUp
|
||||
| RunnerReady
|
||||
| RunnerRunning
|
||||
| RunnerShuttingDown
|
||||
| RunnerShutdown
|
||||
| RunnerFailed
|
||||
)
|
||||
|
||||
@@ -450,6 +450,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
|
||||
|
||||
|
||||
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
# TODO: 'Smart' downloads are disabled because:
|
||||
# (i) We don't handle all kinds of files;
|
||||
# (ii) We don't have sticky sessions.
|
||||
# (iii) Tensor parallel requires all files.
|
||||
return ["*"]
|
||||
try:
|
||||
weight_map = await get_weight_map(str(shard.model_meta.model_id))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
|
||||
@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
|
||||
KEEP_KV_SIZE: int | None = 1600
|
||||
QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
CACHE_GROUP_SIZE: int = 64
|
||||
KV_CACHE_BITS: int | None = 8
|
||||
KV_CACHE_BITS: int | None = None
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
@@ -395,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
"MB. This can be slow. See the documentation for possible work-arounds: "
|
||||
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
|
||||
)
|
||||
kv_bytes = int(0.02 * model_bytes)
|
||||
target_cache = int(1.10 * (model_bytes + kv_bytes))
|
||||
target_cache = min(target_cache, max_rec_size)
|
||||
mx.set_cache_limit(target_cache)
|
||||
mx.set_wired_limit(max_rec_size)
|
||||
logger.info(
|
||||
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
|
||||
)
|
||||
logger.info(f"Wired limit set to {max_rec_size}.")
|
||||
|
||||
@@ -23,6 +23,7 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
@@ -83,7 +84,7 @@ class Worker:
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -128,6 +129,7 @@ class Worker:
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -200,11 +202,11 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard not in self.download_status:
|
||||
if shard.model_meta.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -217,7 +219,7 @@ class Worker:
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -349,7 +351,7 @@ class Worker:
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata] = status
|
||||
self.download_status[task.shard_metadata.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
@@ -363,7 +365,7 @@ class Worker:
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
@@ -384,7 +386,7 @@ class Worker:
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
@@ -444,3 +446,40 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_meta.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -34,7 +35,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def plan(
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
@@ -111,13 +111,14 @@ def _create_runner(
|
||||
|
||||
def _model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_meta.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
not isinstance(
|
||||
download_status.get(runner.bound_instance.bound_shard, None),
|
||||
(DownloadOngoing, DownloadCompleted),
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
@@ -273,6 +274,12 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
|
||||
@@ -32,6 +32,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
@@ -187,13 +188,14 @@ def main(
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
break
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
@@ -208,9 +210,8 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,13 +14,23 @@ from anyio import (
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoading,
|
||||
RunnerRunning,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
@@ -39,10 +49,10 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
# err_path: str
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -77,7 +87,6 @@ class RunnerSupervisor:
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_event_sender=event_sender,
|
||||
# err_path=err_path,
|
||||
)
|
||||
|
||||
return self
|
||||
@@ -118,6 +127,10 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -138,6 +151,22 @@ class RunnerSupervisor:
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
and event.task_status == TaskStatus.Complete
|
||||
):
|
||||
# If a task has just been completed, we should be working on it.
|
||||
assert isinstance(
|
||||
self.status,
|
||||
(
|
||||
RunnerRunning,
|
||||
RunnerWarmingUp,
|
||||
RunnerLoading,
|
||||
RunnerConnecting,
|
||||
RunnerShuttingDown,
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
|
||||
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
|
||||
|
||||
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
|
||||
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
|
||||
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
|
||||
|
||||
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
|
||||
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import BaseTask
|
||||
from exo.shared.types.tasks import BaseTask, TaskId
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
Instance,
|
||||
@@ -21,6 +19,7 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -7,7 +8,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerIdle,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
|
||||
@@ -12,8 +12,10 @@ from exo.worker.tests.constants import (
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
NODE_C,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
RUNNER_3_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import (
|
||||
FakeRunnerSupervisor,
|
||||
@@ -24,37 +26,39 @@ from exo.worker.tests.unittests.conftest import (
|
||||
|
||||
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
"""
|
||||
For non-final device_rank shards, StartWarmup should be emitted when all
|
||||
For non-zero device_rank shards, StartWarmup should be emitted when all
|
||||
shards in the instance are Loaded/WarmingUp.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
RUNNER_3_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
@@ -150,9 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
"""
|
||||
Rank-zero shard should not start warmup until all non-zero ranks are
|
||||
already WarmingUp.
|
||||
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
|
||||
For accepting ranks (device_rank != 0), StartWarmup should be
|
||||
emitted when all shards in the instance are Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 0 is the accepting rank.
|
||||
In a 2-node setup, rank 1 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -163,7 +167,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 0 is the accepting rank
|
||||
# Rank 1 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
@@ -188,6 +192,23 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
@@ -280,9 +301,8 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
|
||||
|
||||
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
"""
|
||||
Connecting rank (device_rank == world_size - 1) should not start warmup
|
||||
Connecting rank (device_rank == 0) should not start warmup
|
||||
until all other ranks are already WarmingUp.
|
||||
In a 2-node setup, rank 1 is the connecting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -295,13 +315,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
@@ -309,7 +329,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
|
||||
@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
@@ -199,6 +200,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
),
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user