mirror of https://github.com/exo-explore/exo.git
synced 2026-02-19 23:36:30 -05:00

Compare commits (1 commit)
ciaran/han...feat/meta-

| Author | SHA1 | Date |
|---|---|---|
|  | b0825335c7 |  |
232 dashboard/src/lib/components/MetaInstanceCard.svelte (Normal file)
@@ -0,0 +1,232 @@
<script lang="ts">
  import type {
    MetaInstance,
    MetaInstanceStatus,
    NodeInfo,
  } from "$lib/stores/app.svelte";
  import {
    getMetaInstanceStatus,
    getMetaInstanceBackingNodes,
    topologyData,
  } from "$lib/stores/app.svelte";

  interface Props {
    metaInstance: MetaInstance;
    onDelete?: (metaInstanceId: string) => void;
  }

  let { metaInstance, onDelete }: Props = $props();

  const status: MetaInstanceStatus = $derived(
    getMetaInstanceStatus(metaInstance),
  );
  const backingNodeIds: string[] = $derived(
    getMetaInstanceBackingNodes(metaInstance),
  );

  const statusConfig = $derived.by(() => {
    switch (status) {
      case "active":
        return {
          label: "ACTIVE",
          dotClass: "bg-green-400",
          borderClass: "border-green-500/30 border-l-green-400",
          cornerClass: "border-green-500/50",
          glowClass: "shadow-[0_0_6px_rgba(74,222,128,0.4)]",
          animate: false,
        };
      case "provisioning":
        return {
          label: "PROVISIONING",
          dotClass: "bg-yellow-400",
          borderClass: "border-exo-yellow/30 border-l-yellow-400",
          cornerClass: "border-yellow-500/50",
          glowClass: "shadow-[0_0_6px_rgba(250,204,21,0.4)]",
          animate: true,
        };
      case "error":
        return {
          label: "ERROR",
          dotClass: "bg-red-400",
          borderClass: "border-red-500/30 border-l-red-400",
          cornerClass: "border-red-500/50",
          glowClass: "shadow-[0_0_6px_rgba(248,113,113,0.4)]",
          animate: false,
        };
    }
  });

  function getNodeName(nodeId: string): string {
    const topo = topologyData();
    if (!topo?.nodes) return nodeId.slice(0, 8);
    const node = topo.nodes[nodeId];
    return node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8);
  }

  function formatModelId(modelId: string): string {
    // Show just the model name part after the org prefix
    const parts = modelId.split("/");
    return parts.length > 1 ? parts[parts.length - 1] : modelId;
  }

  function handleDelete() {
    if (
      onDelete &&
      confirm(
        `Delete meta-instance for ${formatModelId(metaInstance.modelId)}?`,
      )
    ) {
      onDelete(metaInstance.metaInstanceId);
    }
  }
</script>

<div class="relative group">
  <!-- Corner accents -->
  <div
    class="absolute -top-px -left-px w-2 h-2 border-l border-t {statusConfig.cornerClass}"
  ></div>
  <div
    class="absolute -top-px -right-px w-2 h-2 border-r border-t {statusConfig.cornerClass}"
  ></div>
  <div
    class="absolute -bottom-px -left-px w-2 h-2 border-l border-b {statusConfig.cornerClass}"
  ></div>
  <div
    class="absolute -bottom-px -right-px w-2 h-2 border-r border-b {statusConfig.cornerClass}"
  ></div>

  <div
    class="bg-exo-dark-gray/60 border border-l-2 {statusConfig.borderClass} p-3"
  >
    <!-- Header: Status + Delete -->
    <div class="flex justify-between items-start mb-2 pl-2">
      <div class="flex items-center gap-2">
        <div
          class="w-1.5 h-1.5 {statusConfig.dotClass} rounded-full {statusConfig.glowClass} {statusConfig.animate
            ? 'animate-pulse'
            : ''}"
        ></div>
        <span
          class="text-xs font-mono tracking-[0.15em] uppercase {status === 'active'
            ? 'text-green-400'
            : status === 'error'
              ? 'text-red-400'
              : 'text-yellow-400'}"
        >
          {statusConfig.label}
        </span>
      </div>
      <button
        onclick={handleDelete}
        class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
      >
        DELETE
      </button>
    </div>

    <!-- Model Info -->
    <div class="pl-2 space-y-1">
      <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">
        {metaInstance.modelId}
      </div>

      <!-- Sharding + Runtime badges -->
      <div class="flex items-center gap-2">
        <span
          class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
        >
          {metaInstance.sharding}
        </span>
        <span
          class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
        >
          {metaInstance.instanceMeta}
        </span>
        {#if metaInstance.minNodes > 1}
          <span
            class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
          >
            {metaInstance.minNodes}+ nodes
          </span>
        {/if}
      </div>

      <!-- Node Assignments (when active) -->
      {#if backingNodeIds.length > 0}
        <div class="flex items-center gap-1.5 mt-1">
          <svg
            class="w-3 h-3 text-green-400/70 flex-shrink-0"
            viewBox="0 0 24 24"
            fill="none"
            stroke="currentColor"
            stroke-width="2"
          >
            <path
              d="M22 12h-4l-3 9L9 3l-3 9H2"
              stroke-linecap="round"
              stroke-linejoin="round"
            />
          </svg>
          <span class="text-white/60 text-xs font-mono truncate">
            {backingNodeIds.map((id) => getNodeName(id)).join(", ")}
          </span>
        </div>
      {/if}

      <!-- Pinned nodes constraint -->
      {#if metaInstance.nodeIds && metaInstance.nodeIds.length > 0}
        <div class="flex items-center gap-1.5">
          <svg
            class="w-3 h-3 text-white/40 flex-shrink-0"
            viewBox="0 0 24 24"
            fill="none"
            stroke="currentColor"
            stroke-width="2"
          >
            <rect x="3" y="11" width="18" height="11" rx="2" ry="2" />
            <path d="M7 11V7a5 5 0 0 1 10 0v4" />
          </svg>
          <span class="text-white/40 text-[11px] font-mono">
            Pinned: {metaInstance.nodeIds
              .map((id) => getNodeName(id))
              .join(", ")}
          </span>
        </div>
      {/if}

      <!-- Error details -->
      {#if metaInstance.placementError}
        <div
          class="mt-1.5 p-2 bg-red-500/5 border border-red-500/15 rounded-sm"
        >
          <div class="text-red-400 text-[11px] font-mono leading-relaxed">
            {metaInstance.placementError}
          </div>
        </div>
      {/if}

      <!-- Retry counter -->
      {#if metaInstance.consecutiveFailures > 0}
        <div class="flex items-center gap-1.5 mt-1">
          <svg
            class="w-3 h-3 text-yellow-500/60 flex-shrink-0"
            viewBox="0 0 24 24"
            fill="none"
            stroke="currentColor"
            stroke-width="2"
          >
            <polyline points="23 4 23 10 17 10" />
            <path d="M20.49 15a9 9 0 1 1-2.12-9.36L23 10" />
          </svg>
          <span class="text-yellow-500/60 text-[11px] font-mono">
            {metaInstance.consecutiveFailures} consecutive
            failure{metaInstance.consecutiveFailures !== 1 ? "s" : ""}
          </span>
        </div>
      {/if}
    </div>
  </div>
</div>
@@ -11,4 +11,5 @@ export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as MetaInstanceCard } from "./MetaInstanceCard.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";
@@ -72,8 +72,23 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  metaInstanceId?: string | null;
}

export interface MetaInstance {
  metaInstanceId: string;
  modelId: string;
  sharding: "Pipeline" | "Tensor";
  instanceMeta: "MlxRing" | "MlxJaccl";
  minNodes: number;
  nodeIds: string[] | null;
  placementError: string | null;
  consecutiveFailures: number;
  lastFailureError: string | null;
}

export type MetaInstanceStatus = "active" | "provisioning" | "error";

// Granular node state types from the new state structure
interface RawNodeIdentity {
  modelId?: string;
@@ -223,6 +238,7 @@ interface RawStateResponse {
      MlxJacclInstance?: Instance;
    }
  >;
  metaInstances?: Record<string, MetaInstance>;
  runners?: Record<string, unknown>;
  downloads?: Record<string, unknown[]>;
  // New granular node state fields
@@ -533,6 +549,7 @@ class AppStore {
  // Topology state
  topologyData = $state<TopologyData | null>(null);
  instances = $state<Record<string, unknown>>({});
  metaInstances = $state<Record<string, MetaInstance>>({});
  runners = $state<Record<string, unknown>>({});
  downloads = $state<Record<string, unknown[]>>({});
  nodeDisk = $state<
@@ -1268,6 +1285,9 @@ class AppStore {
      this.instances = data.instances;
      this.refreshConversationModelFromInstances();
    }
    if (data.metaInstances) {
      this.metaInstances = data.metaInstances;
    }
    if (data.runners) {
      this.runners = data.runners;
    }
@@ -1293,6 +1313,79 @@ class AppStore {
    }
  }

  async createMetaInstance(
    modelId: string,
    sharding: "Pipeline" | "Tensor" = "Pipeline",
    instanceMeta: "MlxRing" | "MlxJaccl" = "MlxRing",
    minNodes: number = 1,
    nodeIds: string[] | null = null,
  ) {
    try {
      const response = await fetch("/meta_instance", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          model_id: modelId,
          sharding,
          instance_meta: instanceMeta,
          min_nodes: minNodes,
          node_ids: nodeIds,
        }),
      });
      if (!response.ok) {
        console.error("Failed to create meta-instance:", response.status);
      }
      await this.fetchState();
    } catch (error) {
      console.error("Error creating meta-instance:", error);
    }
  }

  async deleteMetaInstance(metaInstanceId: string) {
    try {
      const response = await fetch(`/meta_instance/${metaInstanceId}`, {
        method: "DELETE",
        headers: { "Content-Type": "application/json" },
      });
      if (!response.ok) {
        console.error("Failed to delete meta-instance:", response.status);
      }
      await this.fetchState();
    } catch (error) {
      console.error("Error deleting meta-instance:", error);
    }
  }

  getMetaInstanceStatus(metaInstance: MetaInstance): MetaInstanceStatus {
    // Check if any running instance is bound to this meta-instance
    for (const instanceWrapper of Object.values(this.instances)) {
      if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
      const keys = Object.keys(instanceWrapper as Record<string, unknown>);
      if (keys.length !== 1) continue;
      const inner = (instanceWrapper as Record<string, unknown>)[keys[0]];
      if (
        inner &&
        typeof inner === "object" &&
        (inner as Instance).metaInstanceId === metaInstance.metaInstanceId
      ) {
        return "active";
      }
    }
    if (metaInstance.placementError) return "error";
    return "provisioning";
  }

  getMetaInstanceBackingNodes(metaInstance: MetaInstance): string[] {
    for (const instanceWrapper of Object.values(this.instances)) {
      if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
      const keys = Object.keys(instanceWrapper as Record<string, unknown>);
      if (keys.length !== 1) continue;
      const inner = (instanceWrapper as Record<string, unknown>)[keys[0]] as Instance;
      if (
        inner?.metaInstanceId === metaInstance.metaInstanceId &&
        inner?.shardAssignments?.nodeToRunner
      ) {
        return Object.keys(inner.shardAssignments.nodeToRunner);
      }
    }
    return [];
  }

  async fetchPlacementPreviews(modelId: string, showLoading = true) {
    if (!modelId) return;
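For reference, the two endpoints these store methods wrap can be exercised directly over HTTP. A minimal sketch in Python, assuming the exo API is reachable at localhost:52415 (the host and port are assumptions, not stated in this diff):

import json
import urllib.request

BASE = "http://localhost:52415"  # assumed API address; adjust to your node

def create_meta_instance(
    model_id: str,
    sharding: str = "Pipeline",
    instance_meta: str = "MlxRing",
    min_nodes: int = 1,
    node_ids: list[str] | None = None,
) -> int:
    """POST /meta_instance with the same snake_case body the store sends."""
    body = json.dumps({
        "model_id": model_id,
        "sharding": sharding,
        "instance_meta": instance_meta,
        "min_nodes": min_nodes,
        "node_ids": node_ids,
    }).encode()
    req = urllib.request.Request(
        f"{BASE}/meta_instance",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status

def delete_meta_instance(meta_instance_id: str) -> int:
    """DELETE /meta_instance/{id}, mirroring appStore.deleteMetaInstance."""
    req = urllib.request.Request(
        f"{BASE}/meta_instance/{meta_instance_id}", method="DELETE"
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status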
@@ -3154,6 +3247,7 @@ export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
@@ -3242,6 +3336,21 @@ export const setChatSidebarVisible = (visible: boolean) =>
  appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();

// Meta-instance actions
export const createMetaInstance = (
  modelId: string,
  sharding?: "Pipeline" | "Tensor",
  instanceMeta?: "MlxRing" | "MlxJaccl",
  minNodes?: number,
  nodeIds?: string[] | null,
) => appStore.createMetaInstance(modelId, sharding, instanceMeta, minNodes, nodeIds);
export const deleteMetaInstance = (metaInstanceId: string) =>
  appStore.deleteMetaInstance(metaInstanceId);
export const getMetaInstanceStatus = (metaInstance: MetaInstance) =>
  appStore.getMetaInstanceStatus(metaInstance);
export const getMetaInstanceBackingNodes = (metaInstance: MetaInstance) =>
  appStore.getMetaInstanceBackingNodes(metaInstance);

// Node identities (for OS version mismatch detection)
export const nodeIdentities = () => appStore.nodeIdentities;
@@ -47,10 +47,14 @@
  thunderboltBridgeCycles,
  nodeThunderboltBridge,
  nodeIdentities,
  metaInstances,
  deleteMetaInstance,
  type DownloadProgress,
  type PlacementPreview,
  type MetaInstance,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import MetaInstanceCard from "$lib/components/MetaInstanceCard.svelte";
import { fade, fly } from "svelte/transition";
import { cubicInOut } from "svelte/easing";
import { onMount } from "svelte";
@@ -67,6 +71,8 @@
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const metaInstanceData = $derived(metaInstances());
const metaInstanceCount = $derived(Object.keys(metaInstanceData).length);
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
@@ -3056,6 +3062,39 @@
    </div>
  {/if}

  <!-- Meta-Instances Panel -->
  {#if metaInstanceCount > 0}
    <div class="p-4 flex-shrink-0 border-t border-exo-yellow/10">
      <!-- Panel Header -->
      <div class="flex items-center gap-2 mb-4">
        <div class="w-2 h-2 border border-purple-400/60 rotate-45"></div>
        <h3 class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase">
          Meta-Instances
        </h3>
        <div class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"></div>
        <span class="text-[10px] text-white/40 font-mono">{metaInstanceCount}</span>
      </div>

      <div class="space-y-3">
        {#each Object.entries(metaInstanceData) as [id, mi]}
          <MetaInstanceCard
            metaInstance={mi}
            onDelete={(metaInstanceId) => deleteMetaInstance(metaInstanceId)}
          />
        {/each}
      </div>
    </div>
  {/if}

  <!-- Models Panel - Scrollable -->
  <div class="p-4 flex-1 overflow-y-auto">
    <!-- Panel Header -->
@@ -3878,6 +3917,34 @@
        </div>
      </div>
    {/if}

    <!-- Meta-Instances Section (chat sidebar) -->
    {#if metaInstanceCount > 0}
      <div class="p-4 border-t border-exo-yellow/10">
        <div class="flex items-center gap-2 mb-4">
          <div class="w-2 h-2 border border-purple-400/60 rotate-45"></div>
          <h3 class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase">
            Meta-Instances
          </h3>
          <div class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"></div>
        </div>
        <div class="space-y-3">
          {#each Object.entries(metaInstanceData) as [id, mi]}
            <MetaInstanceCard
              metaInstance={mi}
              onDelete={(metaInstanceId) => deleteMetaInstance(metaInstanceId)}
            />
          {/each}
        </div>
      </div>
    {/if}
  </aside>
{/if}
</div>
@@ -23,8 +23,6 @@ use util::wakerdeque::WakerDeque;

const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);

const MAX_PING_FAILURES: u32 = 3;

mod managed {
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{identity, mdns, ping};
@@ -33,8 +31,8 @@ mod managed {

    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
    const PING_TIMEOUT: Duration = Duration::from_secs(10);
    const PING_INTERVAL: Duration = Duration::from_secs(5);
    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
    const PING_INTERVAL: Duration = Duration::from_millis(2_500);

    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
@@ -111,9 +109,6 @@ pub struct Behaviour {

    // pending events to emit => waker-backed deque to control polling
    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,

    // track consecutive ping failures per connection for N-strike tolerance
    ping_failures: HashMap<ConnectionId, u32>,
}

impl Behaviour {
@@ -123,7 +118,6 @@ impl Behaviour {
            mdns_discovered: HashMap::new(),
            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
            pending_events: WakerDeque::new(),
            ping_failures: HashMap::new(),
        })
    }

@@ -314,7 +308,6 @@ impl NetworkBehaviour for Behaviour {
        };

        if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
            self.ping_failures.remove(&connection_id);
            // handle connection closed event which is filtered correctly
            self.on_connection_closed(peer_id, connection_id, ip, port)
        }
@@ -344,41 +337,10 @@ impl NetworkBehaviour for Behaviour {
            }
        },

        // handle ping events => disconnect after N consecutive failures
        // handle ping events => if error then disconnect
        managed::BehaviourEvent::Ping(e) => {
            match &e.result {
                Err(err) => {
                    let count = self.ping_failures.entry(e.connection).or_insert(0);
                    *count += 1;
                    log::warn!(
                        "Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
                        e.peer,
                        e.connection,
                        err,
                        count,
                        MAX_PING_FAILURES
                    );
                    if *count >= MAX_PING_FAILURES {
                        log::warn!(
                            "Closing connection to peer {:?} after {} consecutive ping failures",
                            e.peer,
                            MAX_PING_FAILURES
                        );
                        self.ping_failures.remove(&e.connection);
                        self.close_connection(e.peer, e.connection);
                    }
                }
                Ok(rtt) => {
                    // Reset failure counter on successful ping
                    if self.ping_failures.remove(&e.connection).is_some() {
                        log::debug!(
                            "Ping recovered for peer {:?} (rtt={:?}), reset failure counter",
                            e.peer,
                            rtt
                        );
                    }
                    log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
                }
            }
            if let Err(_) = e.result {
                self.close_connection(e.peer, e.connection.clone())
            }
        }
    }
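Taken together, these hunks remove the three-strike ping tolerance and tighten PING_TIMEOUT/PING_INTERVAL from 10 s/5 s down to 2.5 s/2.5 s, so a peer is now dropped on its first failed ping. A rough back-of-envelope sketch of the resulting worst-case failure-detection latency (the additive interval-plus-timeout model is an assumption, not from the source):

# Rough sketch (assumption, not from the source): approximate worst-case
# time to drop a dead peer as strikes * (ping interval + ping timeout).
def worst_case_detection_s(interval_s: float, timeout_s: float, strikes: int) -> float:
    return strikes * (interval_s + timeout_s)

old = worst_case_detection_s(5.0, 10.0, 3)  # removed N-strike config: 45.0 s
new = worst_case_detection_s(2.5, 2.5, 1)   # new single-failure config: 5.0 s
print(f"old={old}s new={new}s")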
@@ -1,6 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator

import anyio
from anyio import current_time
@@ -21,9 +22,10 @@ from exo.shared.types.commands import (
    ForwarderDownloadCommand,
    StartDownload,
)
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -34,27 +36,33 @@ from exo.shared.types.worker.downloads import (
    DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel


@dataclass
class DownloadCoordinator:
    node_id: NodeId
    session_id: SessionId
    shard_downloader: ShardDownloader
    download_command_receiver: Receiver[ForwarderDownloadCommand]
    event_sender: Sender[Event]
    local_event_sender: Sender[ForwarderEvent]
    event_index_counter: Iterator[int]
    offline: bool = False

    # Local state
    download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
    active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)

    # Internal event channel for forwarding (initialized in __post_init__)
    event_sender: Sender[Event] = field(init=False)
    event_receiver: Receiver[Event] = field(init=False)
    _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)

    # Per-model throttle for download progress events
    _last_progress_time: dict[ModelId, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.event_sender, self.event_receiver = channel[Event]()
        if self.offline:
            self.shard_downloader.set_internet_connection(False)
        self.shard_downloader.on_progress(self._download_progress_callback)
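The `channel[Event]()` call in `__post_init__` is exo's own helper from `exo.utils.channels`, whose implementation isn't part of this diff. A minimal stand-in illustrating the sender/receiver pairing it presumably provides, built on anyio memory object streams (the shape shown here is an assumption):

# Hypothetical stand-in for exo.utils.channels.channel, illustrating the
# internal sender/receiver pair that __post_init__ wires into the coordinator.
import math

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream


def channel() -> tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]:
    # An unbounded buffer keeps send() from blocking, so the download
    # progress callback never stalls waiting on the forwarding task.
    return anyio.create_memory_object_stream(max_buffer_size=math.inf)


send, recv = channel()  # producer half feeds a _forward_events-style loop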
@@ -109,6 +117,7 @@ class DownloadCoordinator:
        self._test_internet_connection()
        async with self._tg as tg:
            tg.start_soon(self._command_processor)
            tg.start_soon(self._forward_events)
            tg.start_soon(self._emit_existing_download_progress)
            if not self.offline:
                tg.start_soon(self._check_internet_connection)
@@ -288,6 +297,21 @@ class DownloadCoordinator:
            )
            del self.download_status[model_id]

    async def _forward_events(self) -> None:
        with self.event_receiver as events:
            async for event in events:
                idx = next(self.event_index_counter)
                fe = ForwarderEvent(
                    origin_idx=idx,
                    origin=self.node_id,
                    session=self.session_id,
                    event=event,
                )
                logger.debug(
                    f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
                )
                await self.local_event_sender.send(fe)

    async def _emit_existing_download_progress(self) -> None:
        try:
            while True:
@@ -57,8 +57,23 @@ class Node:

        logger.info(f"Starting node {node_id}")

        # Create shared event index counter for Worker and DownloadCoordinator
        event_index_counter = itertools.count()

        # Create DownloadCoordinator (unless --no-downloads)
        if not args.no_downloads:
            download_coordinator = DownloadCoordinator(
                node_id,
                session_id,
                exo_shard_downloader(),
                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                local_event_sender=router.sender(topics.LOCAL_EVENTS),
                event_index_counter=event_index_counter,
                offline=args.offline,
            )
        else:
            download_coordinator = None

        if args.spawn_api:
            api = API(
                node_id,
@@ -85,20 +100,6 @@ class Node:
        else:
            worker = None

        # DownloadCoordinator sends events through the Worker's event channel
        # so they get the same index sequence and retry mechanism
        if not args.no_downloads:
            assert worker is not None, "DownloadCoordinator requires a Worker"
            download_coordinator = DownloadCoordinator(
                node_id,
                exo_shard_downloader(),
                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                event_sender=worker.event_sender.clone(),
                offline=args.offline,
            )
        else:
            download_coordinator = None

        # We start every node with a master
        master = Master(
            node_id,
@@ -213,6 +214,20 @@ class Node:
            await anyio.sleep(0)
            # Fresh counter for new session (buffer expects indices from 0)
            self.event_index_counter = itertools.count()
            if self.download_coordinator:
                self.download_coordinator.shutdown()
                self.download_coordinator = DownloadCoordinator(
                    self.node_id,
                    result.session_id,
                    exo_shard_downloader(),
                    download_command_receiver=self.router.receiver(
                        topics.DOWNLOAD_COMMANDS
                    ),
                    local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                    event_index_counter=self.event_index_counter,
                    offline=self.offline,
                )
                self._tg.start_soon(self.download_coordinator.run)
            if self.worker:
                self.worker.shutdown()
                # TODO: add profiling etc to resource monitor
@@ -230,19 +245,6 @@ class Node:
                    event_index_counter=self.event_index_counter,
                )
                self._tg.start_soon(self.worker.run)
            if self.download_coordinator:
                self.download_coordinator.shutdown()
                assert self.worker is not None
                self.download_coordinator = DownloadCoordinator(
                    self.node_id,
                    exo_shard_downloader(),
                    download_command_receiver=self.router.receiver(
                        topics.DOWNLOAD_COMMANDS
                    ),
                    event_sender=self.worker.event_sender.clone(),
                    offline=self.offline,
                )
                self._tg.start_soon(self.download_coordinator.run)
            if self.api:
                self.api.reset(result.session_id, result.won_clock)
            else:
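The "buffer expects indices from 0" comment is the key invariant here: Worker and DownloadCoordinator share one counter so their events form a single gapless index sequence per session, and a new election must restart that sequence at zero. A small illustrative sketch of the receiver-side ordering this enables (the buffer below is hypothetical, not exo's implementation):

import itertools

# Hypothetical receiver-side buffer: delivers events strictly in index
# order, treating any jump as a gap that must be filled by retries.
class OrderingBuffer:
    def __init__(self) -> None:
        self.expected = 0          # a fresh session starts at index 0
        self.pending: dict[int, str] = {}

    def receive(self, idx: int, event: str) -> list[str]:
        self.pending[idx] = event
        delivered = []
        while self.expected in self.pending:  # drain the contiguous run
            delivered.append(self.pending.pop(self.expected))
            self.expected += 1
        return delivered

counter = itertools.count()        # shared by Worker and DownloadCoordinator
buf = OrderingBuffer()
buf.receive(next(counter), "heartbeat")  # idx 0 -> delivered immediately
buf.receive(next(counter), "download")   # idx 1 -> delivered immediately
# Reusing the old counter after a new election would continue at idx 2,
# leaving indices 0..1 as a permanent gap in a fresh buffer, which is why
# the Node resets event_index_counter = itertools.count() on session reset.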
@@ -1,250 +0,0 @@
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from itertools import count
from pathlib import Path
from typing import AsyncIterator

import anyio
import pytest

from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.master.main import Master
from exo.master.tests.conftest import create_node_memory
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
    ForwarderCommand,
    ForwarderDownloadCommand,
    StartDownload,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
    NodeDownloadProgress,
    NodeGatheredInfo,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.main import Worker


def _complete_progress(shard: ShardMetadata) -> RepoDownloadProgress:
    return RepoDownloadProgress(
        repo_id=str(shard.model_card.model_id),
        repo_revision="test",
        shard=shard,
        completed_files=0,
        total_files=0,
        downloaded_bytes=Memory.from_bytes(0),
        downloaded_bytes_this_session=Memory.from_bytes(0),
        total_bytes=Memory.from_bytes(0),
        overall_speed=0,
        overall_eta=timedelta(seconds=0),
        status="complete",
    )


class _TestShardDownloader(ShardDownloader):
    """Shard downloader that reports every shard as already complete."""

    async def ensure_shard(
        self, shard: ShardMetadata, config_only: bool = False
    ) -> Path:
        return Path("/tmp/test_shard")

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
    ) -> None:
        pass

    async def get_shard_download_status(
        self,
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        # Yield nothing — no pre-existing downloads
        return
        yield  # make this an async generator

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
    ) -> RepoDownloadProgress:
        return _complete_progress(shard)


def _make_heartbeat(node_id: NodeId) -> NodeGatheredInfo:
    return NodeGatheredInfo(
        node_id=node_id,
        when=str(datetime.now(tz=timezone.utc)),
        info=create_node_memory(500),
    )


class _PartitionSwitch:
    """Mutable boolean flag shared with the partition proxy coroutine."""

    def __init__(self) -> None:
        self.connected = True


async def _partition_proxy(
    source: Receiver[ForwarderEvent],
    dest: Sender[ForwarderEvent],
    switch: _PartitionSwitch,
) -> None:
    """Forward events when ``switch.connected`` is True; drop otherwise."""
    with source as events:
        async for event in events:
            if switch.connected:
                await dest.send(event)


async def _wait_until(
    predicate: Callable[[], object], *, timeout: float = 5.0, poll: float = 0.02
) -> None:
    """Poll *predicate* until truthy, raising on timeout."""
    with anyio.fail_after(timeout):
        while not predicate():
            await anyio.sleep(poll)


# ---------------------------------------------------------------------------
# Test 1 – same master: Worker + DC retry recovers lost events
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_partition_recovery_same_master() -> None:
    """Worker's out_for_delivery retry fills the Master's buffer gap after a
    partition heals, even when DownloadCoordinator events are interleaved."""

    master_node = NodeId("master-node")
    worker_node = NodeId("worker-node")
    session = SessionId(master_node_id=master_node, election_clock=1)
    switch = _PartitionSwitch()

    # --- channels --------------------------------------------------------
    # Worker → proxy → Master (local events)
    worker_local_send, proxy_local_recv = channel[ForwarderEvent]()
    proxy_local_send, master_local_recv = channel[ForwarderEvent]()

    # Master → proxy → Worker (global events)
    master_global_send, proxy_global_recv = channel[ForwarderEvent]()
    proxy_global_send, worker_global_recv = channel[ForwarderEvent]()

    # Commands (required by constructors)
    cmd_send, cmd_recv = channel[ForwarderCommand]()
    dl_cmd_send, dl_cmd_recv = channel[ForwarderDownloadCommand]()

    # --- components ------------------------------------------------------
    worker = Worker(
        worker_node,
        session,
        global_event_receiver=worker_global_recv,
        local_event_sender=worker_local_send,
        command_sender=cmd_send.clone(),
        download_command_sender=dl_cmd_send.clone(),
        event_index_counter=count(),
    )

    dc = DownloadCoordinator(
        node_id=worker_node,
        shard_downloader=_TestShardDownloader(),
        download_command_receiver=dl_cmd_recv,
        event_sender=worker.event_sender.clone(),
        offline=True,
    )

    master = Master(
        master_node,
        session,
        command_receiver=cmd_recv,
        local_event_receiver=master_local_recv,
        global_event_sender=master_global_send,
        download_command_sender=dl_cmd_send.clone(),
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(_partition_proxy, proxy_local_recv, proxy_local_send, switch)
        tg.start_soon(_partition_proxy, proxy_global_recv, proxy_global_send, switch)
        tg.start_soon(master.run)
        tg.start_soon(dc.run)
        tg.start_soon(worker.run)

        # 1. Pre-partition: heartbeat reaches master
        await worker.event_sender.send(_make_heartbeat(worker_node))
        await _wait_until(lambda: worker_node in master.state.last_seen)
        initial_last_seen = master.state.last_seen[worker_node]

        # 2. Partition — proxy drops everything
        switch.connected = False

        # Worker heartbeat during partition — lost at proxy, kept in
        # out_for_delivery.
        await worker.event_sender.send(_make_heartbeat(worker_node))

        # Trigger a download via DC's command channel. NoopShardDownloader
        # returns status="complete" for any shard, so _start_download emits
        # NodeDownloadProgress(DownloadPending) then
        # NodeDownloadProgress(DownloadCompleted) through worker.event_sender.
        # These go through _forward_events → proxy (dropped) → out_for_delivery.
        # Use a unique model ID so the DC doesn't skip it as already-completed
        # (it pre-emits progress for the default "noop" model at startup).
        test_shard = PipelineShardMetadata(
            model_card=ModelCard(
                model_id=ModelId("test-partition-model"),
                n_layers=1,
                storage_size=Memory.from_bytes(0),
                hidden_size=1,
                supports_tensor=False,
                tasks=[ModelTask.TextGeneration],
            ),
            device_rank=0,
            world_size=1,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )
        await dl_cmd_send.send(
            ForwarderDownloadCommand(
                origin=worker_node,
                command=StartDownload(
                    target_node_id=worker_node,
                    shard_metadata=test_shard,
                ),
            )
        )

        # Wait for DC events to flow through worker's _forward_events
        # (poll instead of sleeping a fixed duration to avoid flakiness on slow CI)
        await _wait_until(lambda: len(worker.out_for_delivery) >= 3)

        # Verify at least one is a download progress event
        has_download_event = any(
            isinstance(fe.event, NodeDownloadProgress)
            for fe in worker.out_for_delivery.values()
        )
        assert has_download_event, (
            "out_for_delivery should contain DC-originated download events"
        )

        # 3. Heal partition
        switch.connected = True

        # Worker's _resend_out_for_delivery runs every ~1-2s.
        await _wait_until(
            lambda: master.state.last_seen.get(worker_node, initial_last_seen)
            > initial_last_seen,
            timeout=8.0,
        )

        # 4. All events recovered — both worker heartbeats and DC download
        # progress events were retried and accepted by master.
        await _wait_until(lambda: len(worker.out_for_delivery) == 0, timeout=8.0)

        # Master state reflects the download
        assert worker_node in master.state.downloads

        await master.shutdown()
        worker.shutdown()
        dc.shutdown()
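The deleted test above exercised the Worker's out_for_delivery retry path end to end. For orientation, a minimal standalone sketch of that retry pattern (all names here are hypothetical, not exo's actual API):

# Hypothetical sketch of the out_for_delivery retry pattern the deleted
# test verified: unacknowledged events are held by index and periodically
# resent, so a healed partition eventually converges.
import anyio


class ReliableSender:
    def __init__(self, transport, resend_interval: float = 1.0) -> None:
        self.transport = transport          # async callable(idx, event)
        self.resend_interval = resend_interval
        self.out_for_delivery: dict[int, object] = {}

    async def send(self, idx: int, event: object) -> None:
        self.out_for_delivery[idx] = event  # held until acknowledged
        await self.transport(idx, event)    # may be dropped by a partition

    def ack(self, idx: int) -> None:
        self.out_for_delivery.pop(idx, None)  # receiver confirmed delivery

    async def resend_loop(self) -> None:
        # Runs forever; resends everything still unacknowledged.
        while True:
            await anyio.sleep(self.resend_interval)
            for idx, event in list(self.out_for_delivery.items()):
                await self.transport(idx, event)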