Compare commits

..

4 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
e4834dc615 Add support for Ollama API 2026-02-20 00:17:15 +00:00
rltakashige
f662c129dd Prioritise tb for ring instances (#1556)
## Motivation

TB has higher bandwidth and lower latency than Ethernet, so we should
prioritise TB5 where possible. This drastically improves distributed image
generation performance.

## Test Plan

### Manual Testing
Saw on the dashboard that TB (169.254) addresses were prioritised.

Tested that image models scale much better.

### Automated Testing
No regression on Kimi K2.5
2026-02-19 21:32:48 +00:00
Evan Quiney
c45ff9ad43 memory tidy (#1558)
Add some Pythonic extensions to Memory and do a round of general cleanup.
2026-02-19 21:15:33 +00:00
rltakashige
7031901ae5 Prevent common fatal crashes (#1555)
## Motivation
Occasionally, memory does not get released when we shut down. There is
no reason to delay deleting the model.

Handles can also become None during shutdown, causing unhandled TypeErrors
that bring down exo.

Similarly, we were closing the event sender in the wrong place.

We also stop verifying SSL certificates for connections to local peers, as
verification sometimes fails and crashes exo.

## Test Plan

### Manual Testing
No more crashes as described.
2026-02-19 20:51:17 +00:00
24 changed files with 1016 additions and 548 deletions

View File

@@ -1,232 +0,0 @@
<script lang="ts">
import type {
MetaInstance,
MetaInstanceStatus,
NodeInfo,
} from "$lib/stores/app.svelte";
import {
getMetaInstanceStatus,
getMetaInstanceBackingNodes,
topologyData,
} from "$lib/stores/app.svelte";
interface Props {
metaInstance: MetaInstance;
onDelete?: (metaInstanceId: string) => void;
}
let { metaInstance, onDelete }: Props = $props();
const status: MetaInstanceStatus = $derived(
getMetaInstanceStatus(metaInstance),
);
const backingNodeIds: string[] = $derived(
getMetaInstanceBackingNodes(metaInstance),
);
const statusConfig = $derived.by(() => {
switch (status) {
case "active":
return {
label: "ACTIVE",
dotClass: "bg-green-400",
borderClass:
"border-green-500/30 border-l-green-400",
cornerClass: "border-green-500/50",
glowClass: "shadow-[0_0_6px_rgba(74,222,128,0.4)]",
animate: false,
};
case "provisioning":
return {
label: "PROVISIONING",
dotClass: "bg-yellow-400",
borderClass:
"border-exo-yellow/30 border-l-yellow-400",
cornerClass: "border-yellow-500/50",
glowClass: "shadow-[0_0_6px_rgba(250,204,21,0.4)]",
animate: true,
};
case "error":
return {
label: "ERROR",
dotClass: "bg-red-400",
borderClass: "border-red-500/30 border-l-red-400",
cornerClass: "border-red-500/50",
glowClass: "shadow-[0_0_6px_rgba(248,113,113,0.4)]",
animate: false,
};
}
});
function getNodeName(nodeId: string): string {
const topo = topologyData();
if (!topo?.nodes) return nodeId.slice(0, 8);
const node = topo.nodes[nodeId];
return node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8);
}
function formatModelId(modelId: string): string {
// Show just the model name part after the org prefix
const parts = modelId.split("/");
return parts.length > 1 ? parts[parts.length - 1] : modelId;
}
function handleDelete() {
if (
onDelete &&
confirm(
`Delete meta-instance for ${formatModelId(metaInstance.modelId)}?`,
)
) {
onDelete(metaInstance.metaInstanceId);
}
}
</script>
<div class="relative group">
<!-- Corner accents -->
<div
class="absolute -top-px -left-px w-2 h-2 border-l border-t {statusConfig.cornerClass}"
></div>
<div
class="absolute -top-px -right-px w-2 h-2 border-r border-t {statusConfig.cornerClass}"
></div>
<div
class="absolute -bottom-px -left-px w-2 h-2 border-l border-b {statusConfig.cornerClass}"
></div>
<div
class="absolute -bottom-px -right-px w-2 h-2 border-r border-b {statusConfig.cornerClass}"
></div>
<div
class="bg-exo-dark-gray/60 border border-l-2 {statusConfig.borderClass} p-3"
>
<!-- Header: Status + Delete -->
<div class="flex justify-between items-start mb-2 pl-2">
<div class="flex items-center gap-2">
<div
class="w-1.5 h-1.5 {statusConfig.dotClass} rounded-full {statusConfig.glowClass} {statusConfig.animate
? 'animate-pulse'
: ''}"
></div>
<span
class="text-xs font-mono tracking-[0.15em] uppercase {status === 'active'
? 'text-green-400'
: status === 'error'
? 'text-red-400'
: 'text-yellow-400'}"
>
{statusConfig.label}
</span>
</div>
<button
onclick={handleDelete}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<!-- Model Info -->
<div class="pl-2 space-y-1">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">
{metaInstance.modelId}
</div>
<!-- Sharding + Runtime badges -->
<div class="flex items-center gap-2">
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.sharding}
</span>
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.instanceMeta}
</span>
{#if metaInstance.minNodes > 1}
<span
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
>
{metaInstance.minNodes}+ nodes
</span>
{/if}
</div>
<!-- Node Assignments (when active) -->
{#if backingNodeIds.length > 0}
<div class="flex items-center gap-1.5 mt-1">
<svg
class="w-3 h-3 text-green-400/70 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M22 12h-4l-3 9L9 3l-3 9H2"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
<span class="text-white/60 text-xs font-mono truncate">
{backingNodeIds.map((id) => getNodeName(id)).join(", ")}
</span>
</div>
{/if}
<!-- Pinned nodes constraint -->
{#if metaInstance.nodeIds && metaInstance.nodeIds.length > 0}
<div class="flex items-center gap-1.5">
<svg
class="w-3 h-3 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="11" width="18" height="11" rx="2" ry="2" />
<path d="M7 11V7a5 5 0 0 1 10 0v4" />
</svg>
<span class="text-white/40 text-[11px] font-mono">
Pinned: {metaInstance.nodeIds
.map((id) => getNodeName(id))
.join(", ")}
</span>
</div>
{/if}
<!-- Error details -->
{#if metaInstance.placementError}
<div
class="mt-1.5 p-2 bg-red-500/5 border border-red-500/15 rounded-sm"
>
<div class="text-red-400 text-[11px] font-mono leading-relaxed">
{metaInstance.placementError}
</div>
</div>
{/if}
<!-- Retry counter -->
{#if metaInstance.consecutiveFailures > 0}
<div class="flex items-center gap-1.5 mt-1">
<svg
class="w-3 h-3 text-yellow-500/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<polyline points="23 4 23 10 17 10" />
<path d="M20.49 15a9 9 0 1 1-2.12-9.36L23 10" />
</svg>
<span class="text-yellow-500/60 text-[11px] font-mono">
{metaInstance.consecutiveFailures} consecutive
failure{metaInstance.consecutiveFailures !== 1 ? "s" : ""}
</span>
</div>
{/if}
</div>
</div>
</div>

View File

@@ -11,5 +11,4 @@ export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as MetaInstanceCard } from "./MetaInstanceCard.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

View File

@@ -72,23 +72,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
metaInstanceId?: string | null;
}
export interface MetaInstance {
metaInstanceId: string;
modelId: string;
sharding: "Pipeline" | "Tensor";
instanceMeta: "MlxRing" | "MlxJaccl";
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export type MetaInstanceStatus = "active" | "provisioning" | "error";
// Granular node state types from the new state structure
interface RawNodeIdentity {
modelId?: string;
@@ -238,7 +223,6 @@ interface RawStateResponse {
MlxJacclInstance?: Instance;
}
>;
metaInstances?: Record<string, MetaInstance>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
// New granular node state fields
@@ -266,6 +250,11 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}
export interface MessageAttachment {
@@ -549,7 +538,6 @@ class AppStore {
// Topology state
topologyData = $state<TopologyData | null>(null);
instances = $state<Record<string, unknown>>({});
metaInstances = $state<Record<string, MetaInstance>>({});
runners = $state<Record<string, unknown>>({});
downloads = $state<Record<string, unknown[]>>({});
nodeDisk = $state<
@@ -1285,9 +1273,6 @@ class AppStore {
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
if (data.metaInstances) {
this.metaInstances = data.metaInstances;
}
if (data.runners) {
this.runners = data.runners;
}
@@ -1313,79 +1298,6 @@ class AppStore {
}
}
async createMetaInstance(
modelId: string,
sharding: "Pipeline" | "Tensor" = "Pipeline",
instanceMeta: "MlxRing" | "MlxJaccl" = "MlxRing",
minNodes: number = 1,
nodeIds: string[] | null = null,
) {
try {
const response = await fetch("/meta_instance", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model_id: modelId,
sharding,
instance_meta: instanceMeta,
min_nodes: minNodes,
node_ids: nodeIds,
}),
});
if (!response.ok) {
console.error("Failed to create meta-instance:", response.status);
}
await this.fetchState();
} catch (error) {
console.error("Error creating meta-instance:", error);
}
}
async deleteMetaInstance(metaInstanceId: string) {
try {
const response = await fetch(`/meta_instance/${metaInstanceId}`, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
});
if (!response.ok) {
console.error("Failed to delete meta-instance:", response.status);
}
await this.fetchState();
} catch (error) {
console.error("Error deleting meta-instance:", error);
}
}
getMetaInstanceStatus(
metaInstance: MetaInstance,
): MetaInstanceStatus {
// Check if any running instance is bound to this meta-instance
for (const instanceWrapper of Object.values(this.instances)) {
if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
const keys = Object.keys(instanceWrapper as Record<string, unknown>);
if (keys.length !== 1) continue;
const inner = (instanceWrapper as Record<string, unknown>)[keys[0]];
if (inner && typeof inner === "object" && (inner as Instance).metaInstanceId === metaInstance.metaInstanceId) {
return "active";
}
}
if (metaInstance.placementError) return "error";
return "provisioning";
}
getMetaInstanceBackingNodes(metaInstance: MetaInstance): string[] {
for (const instanceWrapper of Object.values(this.instances)) {
if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
const keys = Object.keys(instanceWrapper as Record<string, unknown>);
if (keys.length !== 1) continue;
const inner = (instanceWrapper as Record<string, unknown>)[keys[0]] as Instance;
if (inner?.metaInstanceId === metaInstance.metaInstanceId && inner?.shardAssignments?.nodeToRunner) {
return Object.keys(inner.shardAssignments.nodeToRunner);
}
}
return [];
}
async fetchPlacementPreviews(modelId: string, showLoading = true) {
if (!modelId) return;
@@ -3247,7 +3159,6 @@ export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
@@ -3336,21 +3247,6 @@ export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Meta-instance actions
export const createMetaInstance = (
modelId: string,
sharding?: "Pipeline" | "Tensor",
instanceMeta?: "MlxRing" | "MlxJaccl",
minNodes?: number,
nodeIds?: string[] | null,
) => appStore.createMetaInstance(modelId, sharding, instanceMeta, minNodes, nodeIds);
export const deleteMetaInstance = (metaInstanceId: string) =>
appStore.deleteMetaInstance(metaInstanceId);
export const getMetaInstanceStatus = (metaInstance: MetaInstance) =>
appStore.getMetaInstanceStatus(metaInstance);
export const getMetaInstanceBackingNodes = (metaInstance: MetaInstance) =>
appStore.getMetaInstanceBackingNodes(metaInstance);
// Node identities (for OS version mismatch detection)
export const nodeIdentities = () => appStore.nodeIdentities;

View File

@@ -47,14 +47,10 @@
thunderboltBridgeCycles,
nodeThunderboltBridge,
nodeIdentities,
metaInstances,
deleteMetaInstance,
type DownloadProgress,
type PlacementPreview,
type MetaInstance,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import MetaInstanceCard from "$lib/components/MetaInstanceCard.svelte";
import { fade, fly } from "svelte/transition";
import { cubicInOut } from "svelte/easing";
import { onMount } from "svelte";
@@ -71,8 +67,6 @@
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const metaInstanceData = $derived(metaInstances());
const metaInstanceCount = $derived(Object.keys(metaInstanceData).length);
const sidebarVisible = $derived(chatSidebarVisible());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
const tbBridgeData = $derived(nodeThunderboltBridge());
@@ -864,10 +858,8 @@
if (!progress || typeof progress !== "object") return null;
const prog = progress as Record<string, unknown>;
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const totalBytes = getBytes(prog.total);
const downloadedBytes = getBytes(prog.downloaded);
const speed = (prog.speed as number) ?? 0;
const completedFiles =
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -880,8 +872,8 @@
for (const [fileName, fileData] of Object.entries(filesObj)) {
if (!fileData || typeof fileData !== "object") continue;
const fd = fileData as Record<string, unknown>;
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
const fTotal = getBytes(fd.total);
const fDownloaded = getBytes(fd.downloaded);
files.push({
name: fileName,
totalBytes: fTotal,
@@ -1270,7 +1262,6 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -3062,39 +3053,6 @@
</div>
{/if}
<!-- Meta-Instances Panel -->
{#if metaInstanceCount > 0}
<div class="p-4 flex-shrink-0 border-t border-exo-yellow/10">
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
<div
class="w-2 h-2 border border-purple-400/60 rotate-45"
></div>
<h3
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
>
Meta-Instances
</h3>
<div
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
></div>
<span class="text-[10px] text-white/40 font-mono"
>{metaInstanceCount}</span
>
</div>
<div class="space-y-3">
{#each Object.entries(metaInstanceData) as [id, mi]}
<MetaInstanceCard
metaInstance={mi}
onDelete={(metaInstanceId) =>
deleteMetaInstance(metaInstanceId)}
/>
{/each}
</div>
</div>
{/if}
<!-- Models Panel - Scrollable -->
<div class="p-4 flex-1 overflow-y-auto">
<!-- Panel Header -->
@@ -3917,34 +3875,6 @@
</div>
</div>
{/if}
<!-- Meta-Instances Section (chat sidebar) -->
{#if metaInstanceCount > 0}
<div class="p-4 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-4">
<div
class="w-2 h-2 border border-purple-400/60 rotate-45"
></div>
<h3
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
>
Meta-Instances
</h3>
<div
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
></div>
</div>
<div class="space-y-3">
{#each Object.entries(metaInstanceData) as [id, mi]}
<MetaInstanceCard
metaInstance={mi}
onDelete={(metaInstanceId) =>
deleteMetaInstance(metaInstanceId)}
/>
{/each}
</div>
</div>
{/if}
</aside>
{/if}
</div>

View File

@@ -74,7 +74,6 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -231,23 +230,14 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
const totalBytes = getBytes(payload.total);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

View File

@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
total=initial_progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -578,19 +578,20 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
)
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -609,11 +610,9 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
overall_speed=all_speed,
overall_eta=all_eta,
status=status,

View File

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

View File

@@ -0,0 +1,456 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaDoneReason,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaMessage,
OllamaToolCall,
OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _map_done_reason(
finish_reason: str | None,
) -> OllamaDoneReason | None:
if finish_reason is None:
return None
if finish_reason == "stop":
return "stop"
if finish_reason == "length":
return "length"
if finish_reason in ("tool_calls", "function_call"):
return "tool_call"
if finish_reason == "error":
return "error"
return "stop"
def _try_parse_json(value: str) -> dict[str, Any] | str:
try:
return json.loads(value) # type: ignore
except json.JSONDecodeError:
return value
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
tool_calls: list[OllamaToolCall] = []
for index, tool in enumerate(chunk.tool_calls):
# tool.arguments is always str; try to parse as JSON dict for Ollama format
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
tool_calls.append(
OllamaToolCall(
id=tool.id,
type="function",
function=OllamaToolFunction(
name=tool.name, arguments=arguments, index=index
),
)
)
return tool_calls
def _get_usage(
chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
"""Extract (prompt_eval_count, eval_count) from a chunk."""
if chunk.usage is not None:
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
if chunk.stats is not None:
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
return (None, None)
def ollama_request_to_text_generation(
request: OllamaChatRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama chat request to exo's internal text generation format."""
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
tool_message_index = 0
for msg in request.messages:
content = msg.content or ""
if msg.role == "system":
if instructions is None:
instructions = content
else:
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
continue
if msg.role in ("user", "assistant") and (
msg.content is not None or msg.thinking is not None or msg.tool_calls
):
input_messages.append(InputMessage(role=msg.role, content=content))
dumped: dict[str, Any] = {"role": msg.role, "content": content}
if msg.thinking is not None:
dumped["thinking"] = msg.thinking
if msg.tool_calls is not None:
tool_calls_list: list[dict[str, Any]] = []
for tc in msg.tool_calls:
function: dict[str, Any] = {
"name": tc.function.name,
"arguments": (
json.dumps(tc.function.arguments)
if isinstance(tc.function.arguments, dict)
else tc.function.arguments
),
}
if tc.function.index is not None:
function["index"] = tc.function.index
tool_call: dict[str, Any] = {"function": function}
if tc.id is not None:
tool_call["id"] = tc.id
if tc.type is not None:
tool_call["type"] = tc.type
tool_calls_list.append(tool_call)
dumped["tool_calls"] = tool_calls_list
if msg.name is not None:
dumped["name"] = msg.name
if msg.role == "tool":
tool_message_index += 1
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
dumped["tool_call_id"] = tool_call_id
if msg.tool_name is not None:
dumped["tool_name"] = msg.tool_name
chat_template_messages.append(dumped)
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
tools=request.tools,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_chat_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
error_response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content=chunk.error_message
),
done=True,
done_reason="error",
)
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content="",
tool_calls=_build_tool_calls(chunk),
thinking="".join(thinking_parts) if thinking_parts else None,
),
done=True,
done_reason="tool_call",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content="", thinking=chunk.text
),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content=chunk.text,
),
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
else:
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(role="assistant", content=chunk.text),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_chat_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect streaming chunks into a single non-streaming Ollama response.
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
StreamingResponse cancellation handling.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[OllamaToolCall] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
tool_calls.extend(_build_tool_calls(chunk))
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield OllamaChatResponse(
model=model,
message=OllamaMessage(
role="assistant",
content=combined_text,
thinking=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
# ── /api/generate ──
def ollama_generate_request_to_text_generation(
request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama generate request to exo's internal text generation format."""
chat_template_messages: list[dict[str, Any]] = []
if request.system:
chat_template_messages.append({"role": "system", "content": request.system})
chat_template_messages.append({"role": "user", "content": request.prompt})
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=[InputMessage(role="user", content=request.prompt)],
instructions=request.system,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
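
A minimal sketch of the request mapping above, exercising the adapter directly; the model id, prompt, and option values are illustrative placeholders, not values from this change:

from exo.master.adapters.ollama import ollama_generate_request_to_text_generation
from exo.shared.models.model_cards import ModelId
from exo.shared.types.ollama_api import OllamaGenerateRequest, OllamaOptions

# Hypothetical /api/generate payload; only fields defined on the request model are used.
req = OllamaGenerateRequest(
    model=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
    prompt="Why is the sky blue?",
    system="Answer briefly.",
    options=OllamaOptions(temperature=0.2, num_predict=64),
    stream=False,
)
params = ollama_generate_request_to_text_generation(req)
# The system prompt becomes instructions; options map onto sampling parameters.
assert params.instructions == "Answer briefly."
assert params.max_output_tokens == 64
assert params.stream is False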

View File

@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.ollama import (
collect_ollama_chat_response,
collect_ollama_generate_response,
generate_ollama_chat_stream,
generate_ollama_generate_stream,
ollama_generate_request_to_text_generation,
ollama_request_to_text_generation,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
@@ -141,6 +149,19 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaModelDetails,
OllamaModelTag,
OllamaPsModel,
OllamaPsResponse,
OllamaShowRequest,
OllamaShowResponse,
OllamaTagsResponse,
)
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -300,6 +321,20 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
# Ollama API — health checks (must be before static files mount)
self.app.head("/")(self._ollama_root)
self.app.head("/api/version")(self.ollama_version)
# Ollama API
self.app.post("/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/generate", response_model=None)(self.ollama_generate)
self.app.get("/api/tags")(self.ollama_tags)
self.app.get("/api/api/tags")(self.ollama_tags)
self.app.get("/api/v1/tags")(self.ollama_tags)
self.app.post("/api/show")(self.ollama_show)
self.app.get("/api/ps")(self.ollama_ps)
self.app.get("/api/version")(self.ollama_version)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
@@ -1293,6 +1328,158 @@ class API:
media_type="application/json",
)
async def _ollama_root(self) -> JSONResponse:
"""Respond to HEAD / from Ollama CLI connectivity checks."""
return JSONResponse(content="Ollama is running")
async def ollama_chat(
self, request: Request
) -> OllamaChatResponse | StreamingResponse:
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaChatRequest.model_validate_json(body)
task_params = ollama_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_generate(
self, request: Request
) -> OllamaGenerateResponse | StreamingResponse:
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaGenerateRequest.model_validate_json(body)
task_params = ollama_generate_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_tags(self) -> OllamaTagsResponse:
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
def none_if_empty(value: str) -> str | None:
return value or None
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
]
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return OllamaTagsResponse(
models=[
OllamaModelTag(
name=str(card.model_id),
model=str(card.model_id),
modified_at=now,
size=card.storage_size.in_bytes,
digest="sha256:000000000000",
details=OllamaModelDetails(
family=none_if_empty(card.family),
quantization_level=none_if_empty(card.quantization),
),
)
for card in cards
]
)
async def ollama_show(self, request: Request) -> OllamaShowResponse:
"""Returns model information in Ollama show format."""
body = await request.body()
payload = OllamaShowRequest.model_validate_json(body)
try:
card = await ModelCard.load(ModelId(payload.name))
except Exception as exc:
raise HTTPException(
status_code=404, detail=f"Model not found: {payload.name}"
) from exc
return OllamaShowResponse(
details=OllamaModelDetails(
family=card.family or None,
quantization_level=card.quantization or None,
),
)
async def ollama_ps(self) -> OllamaPsResponse:
"""Returns list of running models (active instances)."""
models: list[OllamaPsModel] = []
seen: set[str] = set()
for instance in self.state.instances.values():
model_id = str(instance.shard_assignments.model_id)
if model_id in seen:
continue
seen.add(model_id)
models.append(
OllamaPsModel(
name=model_id,
model=model_id,
size=0,
)
)
return OllamaPsResponse(models=models)
async def ollama_version(self) -> dict[str, str]:
"""Returns version information for Ollama API compatibility."""
return {"version": "exo v1.0"}
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1322,7 +1509,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
storage_size_megabytes=card.storage_size.in_mb,
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
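
For reference, a client-side sketch of the new Ollama-compatible endpoints; the base URL and model id are assumptions about a local deployment, not taken from this diff:

import httpx

BASE = "http://localhost:52415"  # assumed exo API address; adjust to your deployment

# List downloaded models in Ollama tags format.
tags = httpx.get(f"{BASE}/api/tags").json()
print([m["name"] for m in tags["models"]])

# Non-streaming chat request; stream defaults to true, so disable it explicitly.
resp = httpx.post(
    f"{BASE}/api/chat",
    json={
        "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=120.0,
)
print(resp.json()["message"]["content"])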

View File

@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available / total_memory for node_id in node_ids
],
)
total_storage_bytes = model_card.storage_size.in_bytes
total_storage = model_card.storage_size
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
)
return layer_allocations
@@ -342,6 +341,7 @@ def _find_ip_prioritised(
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
ring: bool,
) -> str | None:
"""Find an IP address between nodes with prioritization.
@@ -354,13 +354,27 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
"wifi": 3,
"unknown": 4,
}
# RDMA prefers ethernet coordinator
else:
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
@@ -400,7 +414,7 @@ def get_mlx_ring_hosts_by_node(
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
node_id, other_node_id, cycle_digraph, node_network, ring=True
)
if connection_ip is None:
raise ValueError(
@@ -431,7 +445,9 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
if ip is not None:
return ip
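
A self-contained sketch of the selection rule used above for ring instances; the interface names and addresses are illustrative:

ring_priority = {"thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, "wifi": 3, "unknown": 4}

def pick_ip(ips: list[str], ip_to_type: dict[str, str], priority: dict[str, int]) -> str:
    # Lower rank wins; unrecognised interface types fall back to a middle rank.
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

ips = ["192.168.1.10", "169.254.7.2"]
ip_to_type = {"192.168.1.10": "ethernet", "169.254.7.2": "thunderbolt"}
assert pick_ip(ips, ip_to_type, ring_priority) == "169.254.7.2"  # TB wins for ring instances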

View File

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
) # make it exactly fit across all nodes
topology = Topology()
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_card.storage_size = Memory.from_bytes(1500)
node_a = NodeId()
node_b = NodeId()

View File

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)
new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
total=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -1,10 +1,10 @@
from math import ceil
from typing import Self
from typing import Self, overload
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class Memory(CamelCaseModel):
class Memory(FrozenModel):
in_bytes: int = 0
@classmethod
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
return cls(in_bytes=round(val * 1024))
@property
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))
@in_mb.setter
def in_mb(self, val: float):
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)
@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
@in_float_mb.setter
def in_float_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))
@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)
def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)
def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented
def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes
def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented
def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes
def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented
def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes
def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))
def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes
def __rmul__(self, other: int | float):
return self * other
@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented
def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented
def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented
def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"
def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelId
# https://github.com/ollama/ollama/blob/main/docs/api.md
OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
class OllamaToolFunction(BaseModel, frozen=True):
name: str
arguments: dict[str, Any] | str
index: int | None = None
class OllamaToolCall(BaseModel, frozen=True):
id: str | None = None
type: Literal["function"] | None = None
function: OllamaToolFunction
class OllamaMessage(BaseModel, frozen=True):
role: OllamaRole
content: str | None = None
thinking: str | None = None
tool_calls: list[OllamaToolCall] | None = None
name: str | None = None
tool_name: str | None = None
images: list[str] | None = None
class OllamaOptions(BaseModel, frozen=True):
num_predict: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
class OllamaChatRequest(BaseModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
options: OllamaOptions | None = None
tools: list[dict[str, Any]] | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
class OllamaGenerateRequest(BaseModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
stream: bool = True
options: OllamaOptions | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
raw: bool = False
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
response: str
thinking: str | None = None
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaShowRequest(BaseModel, frozen=True):
name: str
verbose: bool | None = None
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
message: OllamaMessage
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
format: str | None = None
family: str | None = None
parameter_size: str | None = None
quantization_level: str | None = None
class OllamaModelTag(BaseModel, frozen=True, strict=True):
name: str
model: str | None = None
modified_at: str | None = None
size: int | None = None
digest: str | None = None
details: OllamaModelDetails | None = None
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaModelTag]
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
modelfile: str | None = None
parameters: str | None = None
template: str | None = None
details: OllamaModelDetails | None = None
model_info: dict[str, Any] | None = None
class OllamaPsModel(BaseModel, frozen=True, strict=True):
name: str
model: str
size: int
digest: str | None = None
details: OllamaModelDetails | None = None
expires_at: str | None = None
size_vram: int | None = None
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaPsModel]

View File

@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total: Memory
downloaded: Memory
downloaded_this_session: Memory
completed_files: int
total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
total_bytes: Memory
total: Memory
class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
downloaded: Memory
downloaded_this_session: Memory
total: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]

View File

@@ -192,7 +192,13 @@ class MpReceiver[T]:
try:
return self.receive_nowait()
except WouldBlock:
item = self._state.buffer.get()
try:
item = self._state.buffer.get()
except (TypeError, OSError):
# Queue pipe can get closed while we are blocked on get().
# The underlying connection._handle becomes None, causing
# TypeError in read(handle, remaining).
raise ClosedResourceError from None
if isinstance(item, _MpEndOfStream):
self.close()
raise EndOfStream from None

View File

@@ -108,7 +108,7 @@ async def check_reachable(
await send.send((target_ip, expected_node_id))
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():

View File

@@ -166,7 +166,7 @@ def generate_image(
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
peak_memory = Memory.from_bytes(mx.get_peak_memory())
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
peak_memory_usage=peak_memory,
)
buffer = io.BytesIO()

View File

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
if total_gb >= 128:
return 0.85
if total_gb >= 64:

View File

@@ -232,11 +232,11 @@ def shard_and_load(
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
model_size = get_weights_size(shard_metadata)
timeout_seconds = base_timeout + model_size.in_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
f"(model size: {model_size.in_gb:.1f}GB)"
)
match shard_metadata:
@@ -642,18 +642,17 @@ def set_wired_limit_for_model(model_size: Memory):
if not mx.metal.is_available():
return
model_bytes = model_size.in_bytes
max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
logger.warning(
f"Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size)
mx.set_wired_limit(max_rec_size.in_bytes)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -4,7 +4,7 @@ import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import TYPE_CHECKING, Literal
import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -588,6 +588,13 @@ def main(
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
if not TYPE_CHECKING:
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
@@ -612,12 +619,8 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
if isinstance(current_status, RunnerShutdown):
break

View File

@@ -100,8 +100,8 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
@@ -180,6 +180,7 @@ class RunnerSupervisor:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
self._event_sender.close()
def __del__(self) -> None:
if self.runner_process.is_alive():
@@ -208,10 +209,15 @@ class RunnerSupervisor:
logger.opt(exception=e).error(f"Runner terminated ({cause})")
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
try:
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
except (ClosedResourceError, BrokenResourceError):
logger.warning(
"Event sender already closed, unable to report runner failure"
)
)
self.shutdown()

View File

@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
],
}
@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
],
NODE_B: [],
}
@@ -187,9 +181,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -207,14 +199,10 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
], # NODE_B has no downloads completed yet
}