Compare commits


1 Commit

Author SHA1 Message Date
Alex Cheema
b0825335c7 feat: add meta-instance dashboard UI components
Add MetaInstanceCard component and integrate meta-instance display into
both welcome and chat sidebars. Includes store types, state management,
CRUD operations, and status derivation (active/provisioning/error).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 11:47:23 -08:00
24 changed files with 548 additions and 1016 deletions
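
For orientation, the status derivation called out in the commit message reduces to a small rule, paraphrased below from the getMetaInstanceStatus addition in the store diff further down. The trimmed-down types are illustrative only, not the full store shapes.

// Illustrative sketch of the status rule; the real types live in app.svelte.ts below.
interface MetaInstanceLite {
  metaInstanceId: string;
  placementError: string | null;
}
interface InstanceLite {
  metaInstanceId?: string | null;
}
type MetaInstanceStatusLite = "active" | "provisioning" | "error";

// A meta-instance is "active" once any running instance is bound to it,
// "error" if placement failed, and "provisioning" otherwise.
function deriveStatus(
  meta: MetaInstanceLite,
  runningInstances: InstanceLite[],
): MetaInstanceStatusLite {
  if (runningInstances.some((i) => i.metaInstanceId === meta.metaInstanceId)) {
    return "active";
  }
  if (meta.placementError) return "error";
  return "provisioning";
}

The store version additionally unwraps the tagged single-key instance wrappers before reading metaInstanceId.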

View File

@@ -0,0 +1,232 @@
<script lang="ts">
import type {
MetaInstance,
MetaInstanceStatus,
NodeInfo,
} from "$lib/stores/app.svelte";
import {
getMetaInstanceStatus,
getMetaInstanceBackingNodes,
topologyData,
} from "$lib/stores/app.svelte";
interface Props {
metaInstance: MetaInstance;
onDelete?: (metaInstanceId: string) => void;
}
let { metaInstance, onDelete }: Props = $props();
const status: MetaInstanceStatus = $derived(
getMetaInstanceStatus(metaInstance),
);
const backingNodeIds: string[] = $derived(
getMetaInstanceBackingNodes(metaInstance),
);
const statusConfig = $derived.by(() => {
switch (status) {
case "active":
return {
label: "ACTIVE",
dotClass: "bg-green-400",
borderClass:
"border-green-500/30 border-l-green-400",
cornerClass: "border-green-500/50",
glowClass: "shadow-[0_0_6px_rgba(74,222,128,0.4)]",
animate: false,
};
case "provisioning":
return {
label: "PROVISIONING",
dotClass: "bg-yellow-400",
borderClass:
"border-exo-yellow/30 border-l-yellow-400",
cornerClass: "border-yellow-500/50",
glowClass: "shadow-[0_0_6px_rgba(250,204,21,0.4)]",
animate: true,
};
case "error":
return {
label: "ERROR",
dotClass: "bg-red-400",
borderClass: "border-red-500/30 border-l-red-400",
cornerClass: "border-red-500/50",
glowClass: "shadow-[0_0_6px_rgba(248,113,113,0.4)]",
animate: false,
};
}
});
function getNodeName(nodeId: string): string {
const topo = topologyData();
if (!topo?.nodes) return nodeId.slice(0, 8);
const node = topo.nodes[nodeId];
return node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8);
}
function formatModelId(modelId: string): string {
// Show just the model name part after the org prefix
const parts = modelId.split("/");
return parts.length > 1 ? parts[parts.length - 1] : modelId;
}
function handleDelete() {
if (
onDelete &&
confirm(
`Delete meta-instance for ${formatModelId(metaInstance.modelId)}?`,
)
) {
onDelete(metaInstance.metaInstanceId);
}
}
</script>
<div class="relative group">
<!-- Corner accents -->
<div
class="absolute -top-px -left-px w-2 h-2 border-l border-t {statusConfig.cornerClass}"
></div>
<div
class="absolute -top-px -right-px w-2 h-2 border-r border-t {statusConfig.cornerClass}"
></div>
<div
class="absolute -bottom-px -left-px w-2 h-2 border-l border-b {statusConfig.cornerClass}"
></div>
<div
class="absolute -bottom-px -right-px w-2 h-2 border-r border-b {statusConfig.cornerClass}"
></div>
<div
class="bg-exo-dark-gray/60 border border-l-2 {statusConfig.borderClass} p-3"
>
<!-- Header: Status + Delete -->
<div class="flex justify-between items-start mb-2 pl-2">
<div class="flex items-center gap-2">
<div
class="w-1.5 h-1.5 {statusConfig.dotClass} rounded-full {statusConfig.glowClass} {statusConfig.animate
? 'animate-pulse'
: ''}"
></div>
<span
class="text-xs font-mono tracking-[0.15em] uppercase {status === 'active'
? 'text-green-400'
: status === 'error'
? 'text-red-400'
: 'text-yellow-400'}"
>
{statusConfig.label}
</span>
</div>
<button
onclick={handleDelete}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<!-- Model Info -->
<div class="pl-2 space-y-1">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">
{metaInstance.modelId}
</div>
<!-- Sharding + Runtime badges -->
<div class="flex items-center gap-2">
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.sharding}
</span>
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.instanceMeta}
</span>
{#if metaInstance.minNodes > 1}
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.minNodes}+ nodes
</span>
{/if}
</div>
<!-- Node Assignments (when active) -->
{#if backingNodeIds.length > 0}
<div class="flex items-center gap-1.5 mt-1">
<svg
class="w-3 h-3 text-green-400/70 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M22 12h-4l-3 9L9 3l-3 9H2"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
<span class="text-white/60 text-xs font-mono truncate">
{backingNodeIds.map((id) => getNodeName(id)).join(", ")}
</span>
</div>
{/if}
<!-- Pinned nodes constraint -->
{#if metaInstance.nodeIds && metaInstance.nodeIds.length > 0}
<div class="flex items-center gap-1.5">
<svg
class="w-3 h-3 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="11" width="18" height="11" rx="2" ry="2" />
<path d="M7 11V7a5 5 0 0 1 10 0v4" />
</svg>
<span class="text-white/40 text-[11px] font-mono">
Pinned: {metaInstance.nodeIds
.map((id) => getNodeName(id))
.join(", ")}
</span>
</div>
{/if}
<!-- Error details -->
{#if metaInstance.placementError}
<div
class="mt-1.5 p-2 bg-red-500/5 border border-red-500/15 rounded-sm"
>
<div class="text-red-400 text-[11px] font-mono leading-relaxed">
{metaInstance.placementError}
</div>
</div>
{/if}
<!-- Retry counter -->
{#if metaInstance.consecutiveFailures > 0}
<div class="flex items-center gap-1.5 mt-1">
<svg
class="w-3 h-3 text-yellow-500/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<polyline points="23 4 23 10 17 10" />
<path d="M20.49 15a9 9 0 1 1-2.12-9.36L23 10" />
</svg>
<span class="text-yellow-500/60 text-[11px] font-mono">
{metaInstance.consecutiveFailures} consecutive
failure{metaInstance.consecutiveFailures !== 1 ? "s" : ""}
</span>
</div>
{/if}
</div>
</div>
</div>

View File

@@ -11,4 +11,5 @@ export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as MetaInstanceCard } from "./MetaInstanceCard.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

View File

@@ -72,8 +72,23 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
metaInstanceId?: string | null;
}
export interface MetaInstance {
metaInstanceId: string;
modelId: string;
sharding: "Pipeline" | "Tensor";
instanceMeta: "MlxRing" | "MlxJaccl";
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export type MetaInstanceStatus = "active" | "provisioning" | "error";
// Granular node state types from the new state structure
interface RawNodeIdentity {
modelId?: string;
@@ -223,6 +238,7 @@ interface RawStateResponse {
MlxJacclInstance?: Instance;
}
>;
metaInstances?: Record<string, MetaInstance>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
// New granular node state fields
@@ -250,11 +266,6 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}
export interface MessageAttachment {
@@ -538,6 +549,7 @@ class AppStore {
// Topology state
topologyData = $state<TopologyData | null>(null);
instances = $state<Record<string, unknown>>({});
metaInstances = $state<Record<string, MetaInstance>>({});
runners = $state<Record<string, unknown>>({});
downloads = $state<Record<string, unknown[]>>({});
nodeDisk = $state<
@@ -1273,6 +1285,9 @@ class AppStore {
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
if (data.metaInstances) {
this.metaInstances = data.metaInstances;
}
if (data.runners) {
this.runners = data.runners;
}
@@ -1298,6 +1313,79 @@ class AppStore {
}
}
async createMetaInstance(
modelId: string,
sharding: "Pipeline" | "Tensor" = "Pipeline",
instanceMeta: "MlxRing" | "MlxJaccl" = "MlxRing",
minNodes: number = 1,
nodeIds: string[] | null = null,
) {
try {
const response = await fetch("/meta_instance", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model_id: modelId,
sharding,
instance_meta: instanceMeta,
min_nodes: minNodes,
node_ids: nodeIds,
}),
});
if (!response.ok) {
console.error("Failed to create meta-instance:", response.status);
}
await this.fetchState();
} catch (error) {
console.error("Error creating meta-instance:", error);
}
}
async deleteMetaInstance(metaInstanceId: string) {
try {
const response = await fetch(`/meta_instance/${metaInstanceId}`, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
});
if (!response.ok) {
console.error("Failed to delete meta-instance:", response.status);
}
await this.fetchState();
} catch (error) {
console.error("Error deleting meta-instance:", error);
}
}
getMetaInstanceStatus(
metaInstance: MetaInstance,
): MetaInstanceStatus {
// Check if any running instance is bound to this meta-instance
for (const instanceWrapper of Object.values(this.instances)) {
if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
const keys = Object.keys(instanceWrapper as Record<string, unknown>);
if (keys.length !== 1) continue;
const inner = (instanceWrapper as Record<string, unknown>)[keys[0]];
if (inner && typeof inner === "object" && (inner as Instance).metaInstanceId === metaInstance.metaInstanceId) {
return "active";
}
}
if (metaInstance.placementError) return "error";
return "provisioning";
}
getMetaInstanceBackingNodes(metaInstance: MetaInstance): string[] {
for (const instanceWrapper of Object.values(this.instances)) {
if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
const keys = Object.keys(instanceWrapper as Record<string, unknown>);
if (keys.length !== 1) continue;
const inner = (instanceWrapper as Record<string, unknown>)[keys[0]] as Instance;
if (inner?.metaInstanceId === metaInstance.metaInstanceId && inner?.shardAssignments?.nodeToRunner) {
return Object.keys(inner.shardAssignments.nodeToRunner);
}
}
return [];
}
async fetchPlacementPreviews(modelId: string, showLoading = true) {
if (!modelId) return;
@@ -3159,6 +3247,7 @@ export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
@@ -3247,6 +3336,21 @@ export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Meta-instance actions
export const createMetaInstance = (
modelId: string,
sharding?: "Pipeline" | "Tensor",
instanceMeta?: "MlxRing" | "MlxJaccl",
minNodes?: number,
nodeIds?: string[] | null,
) => appStore.createMetaInstance(modelId, sharding, instanceMeta, minNodes, nodeIds);
export const deleteMetaInstance = (metaInstanceId: string) =>
appStore.deleteMetaInstance(metaInstanceId);
export const getMetaInstanceStatus = (metaInstance: MetaInstance) =>
appStore.getMetaInstanceStatus(metaInstance);
export const getMetaInstanceBackingNodes = (metaInstance: MetaInstance) =>
appStore.getMetaInstanceBackingNodes(metaInstance);
// Node identities (for OS version mismatch detection)
export const nodeIdentities = () => appStore.nodeIdentities;
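
A minimal sketch of how a caller might drive these new exports; the model id below is a placeholder, not one shipped with this change.

// Hypothetical consumer of the new meta-instance store actions.
import {
  createMetaInstance,
  deleteMetaInstance,
  getMetaInstanceStatus,
  metaInstances,
} from "$lib/stores/app.svelte";

async function demoMetaInstanceLifecycle(): Promise<void> {
  // Ask the backend to keep one pipeline-sharded MLX ring instance alive
  // (these arguments match the defaults; they are spelled out for clarity).
  await createMetaInstance("org/placeholder-model", "Pipeline", "MlxRing", 1);

  // createMetaInstance refreshes state after the POST, so the map is current.
  for (const [id, mi] of Object.entries(metaInstances())) {
    const status = getMetaInstanceStatus(mi); // "active" | "provisioning" | "error"
    console.log(`${id}: ${mi.modelId} -> ${status}`);
    if (status === "error") {
      // Tear down meta-instances whose placement failed.
      await deleteMetaInstance(id);
    }
  }
}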

View File

@@ -47,10 +47,14 @@
thunderboltBridgeCycles,
nodeThunderboltBridge,
nodeIdentities,
metaInstances,
deleteMetaInstance,
type DownloadProgress,
type PlacementPreview,
type MetaInstance,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import MetaInstanceCard from "$lib/components/MetaInstanceCard.svelte";
import { fade, fly } from "svelte/transition";
import { cubicInOut } from "svelte/easing";
import { onMount } from "svelte";
@@ -67,6 +71,8 @@
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const metaInstanceData = $derived(metaInstances());
const metaInstanceCount = $derived(Object.keys(metaInstanceData).length);
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
@@ -858,8 +864,10 @@
if (!progress || typeof progress !== "object") return null;
const prog = progress as Record<string, unknown>;
const totalBytes = getBytes(prog.total);
const downloadedBytes = getBytes(prog.downloaded);
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0;
const completedFiles =
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -872,8 +880,8 @@
for (const [fileName, fileData] of Object.entries(filesObj)) {
if (!fileData || typeof fileData !== "object") continue;
const fd = fileData as Record<string, unknown>;
const fTotal = getBytes(fd.total);
const fDownloaded = getBytes(fd.downloaded);
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
files.push({
name: fileName,
totalBytes: fTotal,
@@ -1262,6 +1270,7 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -3053,6 +3062,39 @@
</div>
{/if}
<!-- Meta-Instances Panel -->
{#if metaInstanceCount > 0}
<div class="p-4 flex-shrink-0 border-t border-exo-yellow/10">
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
<div
class="w-2 h-2 border border-purple-400/60 rotate-45"
></div>
<h3
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
>
Meta-Instances
</h3>
<div
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
></div>
<span class="text-[10px] text-white/40 font-mono"
>{metaInstanceCount}</span
>
</div>
<div class="space-y-3">
{#each Object.entries(metaInstanceData) as [id, mi]}
<MetaInstanceCard
metaInstance={mi}
onDelete={(metaInstanceId) =>
deleteMetaInstance(metaInstanceId)}
/>
{/each}
</div>
</div>
{/if}
<!-- Models Panel - Scrollable -->
<div class="p-4 flex-1 overflow-y-auto">
<!-- Panel Header -->
@@ -3875,6 +3917,34 @@
</div>
</div>
{/if}
<!-- Meta-Instances Section (chat sidebar) -->
{#if metaInstanceCount > 0}
<div class="p-4 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-4">
<div
class="w-2 h-2 border border-purple-400/60 rotate-45"
></div>
<h3
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
>
Meta-Instances
</h3>
<div
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
></div>
</div>
<div class="space-y-3">
{#each Object.entries(metaInstanceData) as [id, mi]}
<MetaInstanceCard
metaInstance={mi}
onDelete={(metaInstanceId) =>
deleteMetaInstance(metaInstanceId)}
/>
{/each}
</div>
</div>
{/if}
</aside>
{/if}
</div>

View File

@@ -74,6 +74,7 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -230,14 +231,23 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(payload.total);
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

View File

@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total=progress.total,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total=initial_progress.total,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total=progress.total,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_this_session.in_bytes == 0:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -578,20 +578,19 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
)
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -610,9 +609,11 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
overall_speed=all_speed,
overall_eta=all_eta,
status=status,

View File

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

View File

@@ -1,456 +0,0 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaDoneReason,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaMessage,
OllamaToolCall,
OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _map_done_reason(
finish_reason: str | None,
) -> OllamaDoneReason | None:
if finish_reason is None:
return None
if finish_reason == "stop":
return "stop"
if finish_reason == "length":
return "length"
if finish_reason in ("tool_calls", "function_call"):
return "tool_call"
if finish_reason == "error":
return "error"
return "stop"
def _try_parse_json(value: str) -> dict[str, Any] | str:
try:
return json.loads(value) # type: ignore
except json.JSONDecodeError:
return value
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
tool_calls: list[OllamaToolCall] = []
for index, tool in enumerate(chunk.tool_calls):
# tool.arguments is always str; try to parse as JSON dict for Ollama format
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
tool_calls.append(
OllamaToolCall(
id=tool.id,
type="function",
function=OllamaToolFunction(
name=tool.name, arguments=arguments, index=index
),
)
)
return tool_calls
def _get_usage(
chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
"""Extract (prompt_eval_count, eval_count) from a chunk."""
if chunk.usage is not None:
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
if chunk.stats is not None:
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
return (None, None)
def ollama_request_to_text_generation(
request: OllamaChatRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama chat request to exo's internal text generation format."""
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
tool_message_index = 0
for msg in request.messages:
content = msg.content or ""
if msg.role == "system":
if instructions is None:
instructions = content
else:
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
continue
if msg.role in ("user", "assistant") and (
msg.content is not None or msg.thinking is not None or msg.tool_calls
):
input_messages.append(InputMessage(role=msg.role, content=content))
dumped: dict[str, Any] = {"role": msg.role, "content": content}
if msg.thinking is not None:
dumped["thinking"] = msg.thinking
if msg.tool_calls is not None:
tool_calls_list: list[dict[str, Any]] = []
for tc in msg.tool_calls:
function: dict[str, Any] = {
"name": tc.function.name,
"arguments": (
json.dumps(tc.function.arguments)
if isinstance(tc.function.arguments, dict)
else tc.function.arguments
),
}
if tc.function.index is not None:
function["index"] = tc.function.index
tool_call: dict[str, Any] = {"function": function}
if tc.id is not None:
tool_call["id"] = tc.id
if tc.type is not None:
tool_call["type"] = tc.type
tool_calls_list.append(tool_call)
dumped["tool_calls"] = tool_calls_list
if msg.name is not None:
dumped["name"] = msg.name
if msg.role == "tool":
tool_message_index += 1
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
dumped["tool_call_id"] = tool_call_id
if msg.tool_name is not None:
dumped["tool_name"] = msg.tool_name
chat_template_messages.append(dumped)
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
tools=request.tools,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_chat_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
error_response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content=chunk.error_message
),
done=True,
done_reason="error",
)
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content="",
tool_calls=_build_tool_calls(chunk),
thinking="".join(thinking_parts) if thinking_parts else None,
),
done=True,
done_reason="tool_call",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content="", thinking=chunk.text
),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content=chunk.text,
),
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
else:
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(role="assistant", content=chunk.text),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_chat_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect streaming chunks into a single non-streaming Ollama response.
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
StreamingResponse cancellation handling.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[OllamaToolCall] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
tool_calls.extend(_build_tool_calls(chunk))
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield OllamaChatResponse(
model=model,
message=OllamaMessage(
role="assistant",
content=combined_text,
thinking=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
# ── /api/generate ──
def ollama_generate_request_to_text_generation(
request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama generate request to exo's internal text generation format."""
chat_template_messages: list[dict[str, Any]] = []
if request.system:
chat_template_messages.append({"role": "system", "content": request.system})
chat_template_messages.append({"role": "user", "content": request.prompt})
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=[InputMessage(role="user", content=request.prompt)],
instructions=request.system,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return

View File

@@ -32,14 +32,6 @@ from exo.master.adapters.claude import (
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.ollama import (
collect_ollama_chat_response,
collect_ollama_generate_response,
generate_ollama_chat_stream,
generate_ollama_generate_stream,
ollama_generate_request_to_text_generation,
ollama_request_to_text_generation,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
@@ -149,19 +141,6 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaModelDetails,
OllamaModelTag,
OllamaPsModel,
OllamaPsResponse,
OllamaShowRequest,
OllamaShowResponse,
OllamaTagsResponse,
)
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -321,20 +300,6 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
# Ollama API — health checks (must be before static files mount)
self.app.head("/")(self._ollama_root)
self.app.head("/api/version")(self.ollama_version)
# Ollama API
self.app.post("/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/generate", response_model=None)(self.ollama_generate)
self.app.get("/api/tags")(self.ollama_tags)
self.app.get("/api/api/tags")(self.ollama_tags)
self.app.get("/api/v1/tags")(self.ollama_tags)
self.app.post("/api/show")(self.ollama_show)
self.app.get("/api/ps")(self.ollama_ps)
self.app.get("/api/version")(self.ollama_version)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
@@ -1328,158 +1293,6 @@ class API:
media_type="application/json",
)
async def _ollama_root(self) -> JSONResponse:
"""Respond to HEAD / from Ollama CLI connectivity checks."""
return JSONResponse(content="Ollama is running")
async def ollama_chat(
self, request: Request
) -> OllamaChatResponse | StreamingResponse:
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaChatRequest.model_validate_json(body)
task_params = ollama_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_generate(
self, request: Request
) -> OllamaGenerateResponse | StreamingResponse:
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaGenerateRequest.model_validate_json(body)
task_params = ollama_generate_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_tags(self) -> OllamaTagsResponse:
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
def none_if_empty(value: str) -> str | None:
return value or None
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
]
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return OllamaTagsResponse(
models=[
OllamaModelTag(
name=str(card.model_id),
model=str(card.model_id),
modified_at=now,
size=card.storage_size.in_bytes,
digest="sha256:000000000000",
details=OllamaModelDetails(
family=none_if_empty(card.family),
quantization_level=none_if_empty(card.quantization),
),
)
for card in cards
]
)
async def ollama_show(self, request: Request) -> OllamaShowResponse:
"""Returns model information in Ollama show format."""
body = await request.body()
payload = OllamaShowRequest.model_validate_json(body)
try:
card = await ModelCard.load(ModelId(payload.name))
except Exception as exc:
raise HTTPException(
status_code=404, detail=f"Model not found: {payload.name}"
) from exc
return OllamaShowResponse(
details=OllamaModelDetails(
family=card.family or None,
quantization_level=card.quantization or None,
),
)
async def ollama_ps(self) -> OllamaPsResponse:
"""Returns list of running models (active instances)."""
models: list[OllamaPsModel] = []
seen: set[str] = set()
for instance in self.state.instances.values():
model_id = str(instance.shard_assignments.model_id)
if model_id in seen:
continue
seen.add(model_id)
models.append(
OllamaPsModel(
name=model_id,
model=model_id,
size=0,
)
)
return OllamaPsResponse(models=models)
async def ollama_version(self) -> dict[str, str]:
"""Returns version information for Ollama API compatibility."""
return {"version": "exo v1.0"}
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1509,7 +1322,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=card.storage_size.in_mb,
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),

View File

@@ -102,21 +102,22 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available / total_memory for node_id in node_ids
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
],
)
total_storage = model_card.storage_size
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
return layer_allocations
@@ -341,7 +342,6 @@ def _find_ip_prioritised(
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
ring: bool,
) -> str | None:
"""Find an IP address between nodes with prioritization.
@@ -354,27 +354,13 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
"wifi": 3,
"unknown": 4,
}
# RDMA prefers ethernet coordinator
else:
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
@@ -414,7 +400,7 @@ def get_mlx_ring_hosts_by_node(
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network, ring=True
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
raise ValueError(
@@ -445,9 +431,7 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
if ip is not None:
return ip

View File

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size = Memory.from_bytes(1500)
model_card.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()

View File

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total=Memory(),
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -1,10 +1,10 @@
from math import ceil
from typing import Self, overload
from typing import Self
from exo.utils.pydantic_ext import FrozenModel
from exo.utils.pydantic_ext import CamelCaseModel
class Memory(FrozenModel):
class Memory(CamelCaseModel):
in_bytes: int = 0
@classmethod
@@ -33,22 +33,12 @@ class Memory(FrozenModel):
return cls(in_bytes=round(val * 1024))
@property
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))
@in_mb.setter
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)
@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
@in_float_mb.setter
def in_float_mb(self, val: float):
@in_mb.setter
def in_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))
@@ -67,85 +57,17 @@ class Memory(FrozenModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)
def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented
def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)
def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented
def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes
def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented
def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes
def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))
def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes
def __rmul__(self, other: int | float):
return self * other
@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented
def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented
def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented
def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"
def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes

View File

@@ -1,147 +0,0 @@
from __future__ import annotations
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelId
# https://github.com/ollama/ollama/blob/main/docs/api.md
OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
class OllamaToolFunction(BaseModel, frozen=True):
name: str
arguments: dict[str, Any] | str
index: int | None = None
class OllamaToolCall(BaseModel, frozen=True):
id: str | None = None
type: Literal["function"] | None = None
function: OllamaToolFunction
class OllamaMessage(BaseModel, frozen=True):
role: OllamaRole
content: str | None = None
thinking: str | None = None
tool_calls: list[OllamaToolCall] | None = None
name: str | None = None
tool_name: str | None = None
images: list[str] | None = None
class OllamaOptions(BaseModel, frozen=True):
num_predict: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
class OllamaChatRequest(BaseModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
options: OllamaOptions | None = None
tools: list[dict[str, Any]] | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
class OllamaGenerateRequest(BaseModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
stream: bool = True
options: OllamaOptions | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
raw: bool = False
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
response: str
thinking: str | None = None
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaShowRequest(BaseModel, frozen=True):
name: str
verbose: bool | None = None
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
message: OllamaMessage
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
format: str | None = None
family: str | None = None
parameter_size: str | None = None
quantization_level: str | None = None
class OllamaModelTag(BaseModel, frozen=True, strict=True):
name: str
model: str | None = None
modified_at: str | None = None
size: int | None = None
digest: str | None = None
details: OllamaModelDetails | None = None
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaModelTag]
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
modelfile: str | None = None
parameters: str | None = None
template: str | None = None
details: OllamaModelDetails | None = None
model_info: dict[str, Any] | None = None
class OllamaPsModel(BaseModel, frozen=True, strict=True):
name: str
model: str
size: int
digest: str | None = None
details: OllamaModelDetails | None = None
expires_at: str | None = None
size_vram: int | None = None
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaPsModel]

View File

@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total: Memory
downloaded: Memory
downloaded_this_session: Memory
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
completed_files: int
total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
total: Memory
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
shard: ShardMetadata
completed_files: int
total_files: int
downloaded: Memory
downloaded_this_session: Memory
total: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
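
Because these models extend CamelCaseModel, the renamed fields presumably serialize with camelCase aliases on the wire, which is what the sidebar's total_bytes ?? totalBytes and in_bytes / inBytes fallbacks anticipate. A rough sketch of the assumed payload shape and the defensive read, mirroring the Svelte pages above:

// Assumed wire shape after the rename (camelCase aliases of the fields above);
// Memory is taken to serialize as an object carrying an inBytes count.
interface MemoryWire {
  inBytes: number;
}
interface DownloadProgressWire {
  totalBytes: MemoryWire;
  downloadedBytes: MemoryWire;
  downloadedBytesThisSession: MemoryWire;
  completedFiles: number;
  totalFiles: number;
  speed: number;
  etaMs?: number;
}

// Accept raw numbers, snake_case, or camelCase so older and newer payloads both parse.
function bytesOf(value: unknown): number {
  if (typeof value === "number") return value;
  if (value && typeof value === "object") {
    const v = value as Record<string, unknown>;
    if (typeof v.in_bytes === "number") return v.in_bytes;
    if (typeof v.inBytes === "number") return v.inBytes;
  }
  return 0;
}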

View File

@@ -192,13 +192,7 @@ class MpReceiver[T]:
try:
return self.receive_nowait()
except WouldBlock:
try:
item = self._state.buffer.get()
except (TypeError, OSError):
# Queue pipe can get closed while we are blocked on get().
# The underlying connection._handle becomes None, causing
# TypeError in read(handle, remaining).
raise ClosedResourceError from None
item = self._state.buffer.get()
if isinstance(item, _MpEndOfStream):
self.close()
raise EndOfStream from None

View File

@@ -108,7 +108,7 @@ async def check_reachable(
await send.send((target_ip, expected_node_id))
async with (
httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():

View File

@@ -166,7 +166,7 @@ def generate_image(
else 0.0
)
peak_memory = Memory.from_bytes(mx.get_peak_memory())
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=peak_memory,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()

View File

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:

View File

@@ -232,11 +232,11 @@ def shard_and_load(
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size = get_weights_size(shard_metadata)
timeout_seconds = base_timeout + model_size.in_gb
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size.in_gb:.1f}GB)"
f"(model size: {model_size_gb:.1f}GB)"
)
match shard_metadata:
@@ -642,17 +642,18 @@ def set_wired_limit_for_model(model_size: Memory):
if not mx.metal.is_available():
return
max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
model_bytes = model_size.in_bytes
max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
logger.warning(
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
f"Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size.in_bytes)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -4,7 +4,7 @@ import resource
import time
from collections.abc import Generator
from functools import cache
from typing import TYPE_CHECKING, Literal
from typing import Literal
import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -588,13 +588,6 @@ def main(
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
if not TYPE_CHECKING:
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
@@ -619,8 +612,12 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
break

View File

@@ -100,8 +100,8 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
@@ -180,7 +180,6 @@ class RunnerSupervisor:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
self._event_sender.close()
def __del__(self) -> None:
if self.runner_process.is_alive():
@@ -209,15 +208,10 @@ class RunnerSupervisor:
logger.opt(exception=e).error(f"Runner terminated ({cause})")
try:
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
except (ClosedResourceError, BrokenResourceError):
logger.warning(
"Event sender already closed, unable to report runner failure"
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
self.shutdown()

View File

@@ -90,10 +90,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}
@@ -134,7 +138,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
@@ -181,7 +187,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -199,10 +207,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}