Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 23:36:30 -05:00)
Compare commits: leo/add-ol... vs feat/meta- (1 commit, b0825335c7)

dashboard/src/lib/components/MetaInstanceCard.svelte (new file, 232 lines)
@@ -0,0 +1,232 @@
<script lang="ts">
  import type {
    MetaInstance,
    MetaInstanceStatus,
    NodeInfo,
  } from "$lib/stores/app.svelte";
  import {
    getMetaInstanceStatus,
    getMetaInstanceBackingNodes,
    topologyData,
  } from "$lib/stores/app.svelte";

  interface Props {
    metaInstance: MetaInstance;
    onDelete?: (metaInstanceId: string) => void;
  }

  let { metaInstance, onDelete }: Props = $props();

  const status: MetaInstanceStatus = $derived(
    getMetaInstanceStatus(metaInstance),
  );
  const backingNodeIds: string[] = $derived(
    getMetaInstanceBackingNodes(metaInstance),
  );

  const statusConfig = $derived.by(() => {
    switch (status) {
      case "active":
        return {
          label: "ACTIVE",
          dotClass: "bg-green-400",
          borderClass: "border-green-500/30 border-l-green-400",
          cornerClass: "border-green-500/50",
          glowClass: "shadow-[0_0_6px_rgba(74,222,128,0.4)]",
          animate: false,
        };
      case "provisioning":
        return {
          label: "PROVISIONING",
          dotClass: "bg-yellow-400",
          borderClass: "border-exo-yellow/30 border-l-yellow-400",
          cornerClass: "border-yellow-500/50",
          glowClass: "shadow-[0_0_6px_rgba(250,204,21,0.4)]",
          animate: true,
        };
      case "error":
        return {
          label: "ERROR",
          dotClass: "bg-red-400",
          borderClass: "border-red-500/30 border-l-red-400",
          cornerClass: "border-red-500/50",
          glowClass: "shadow-[0_0_6px_rgba(248,113,113,0.4)]",
          animate: false,
        };
    }
  });

  function getNodeName(nodeId: string): string {
    const topo = topologyData();
    if (!topo?.nodes) return nodeId.slice(0, 8);
    const node = topo.nodes[nodeId];
    return node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8);
  }

  function formatModelId(modelId: string): string {
    // Show just the model name part after the org prefix
    const parts = modelId.split("/");
    return parts.length > 1 ? parts[parts.length - 1] : modelId;
  }

  function handleDelete() {
    if (
      onDelete &&
      confirm(
        `Delete meta-instance for ${formatModelId(metaInstance.modelId)}?`,
      )
    ) {
      onDelete(metaInstance.metaInstanceId);
    }
  }
</script>

<div class="relative group">
|
||||
<!-- Corner accents -->
|
||||
<div
|
||||
class="absolute -top-px -left-px w-2 h-2 border-l border-t {statusConfig.cornerClass}"
|
||||
></div>
|
||||
<div
|
||||
class="absolute -top-px -right-px w-2 h-2 border-r border-t {statusConfig.cornerClass}"
|
||||
></div>
|
||||
<div
|
||||
class="absolute -bottom-px -left-px w-2 h-2 border-l border-b {statusConfig.cornerClass}"
|
||||
></div>
|
||||
<div
|
||||
class="absolute -bottom-px -right-px w-2 h-2 border-r border-b {statusConfig.cornerClass}"
|
||||
></div>
|
||||
|
||||
<div
|
||||
class="bg-exo-dark-gray/60 border border-l-2 {statusConfig.borderClass} p-3"
|
||||
>
|
||||
<!-- Header: Status + Delete -->
|
||||
<div class="flex justify-between items-start mb-2 pl-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<div
|
||||
class="w-1.5 h-1.5 {statusConfig.dotClass} rounded-full {statusConfig.glowClass} {statusConfig.animate
|
||||
? 'animate-pulse'
|
||||
: ''}"
|
||||
></div>
|
||||
<span
|
||||
class="text-xs font-mono tracking-[0.15em] uppercase {status === 'active'
|
||||
? 'text-green-400'
|
||||
: status === 'error'
|
||||
? 'text-red-400'
|
||||
: 'text-yellow-400'}"
|
||||
>
|
||||
{statusConfig.label}
|
||||
</span>
|
||||
</div>
|
||||
<button
|
||||
onclick={handleDelete}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Model Info -->
|
||||
<div class="pl-2 space-y-1">
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">
|
||||
{metaInstance.modelId}
|
||||
</div>
|
||||
|
||||
<!-- Sharding + Runtime badges -->
|
||||
<div class="flex items-center gap-2">
|
||||
<span
|
||||
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
|
||||
>
|
||||
{metaInstance.sharding}
|
||||
</span>
|
||||
<span
|
||||
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
|
||||
>
|
||||
{metaInstance.instanceMeta}
|
||||
</span>
|
||||
{#if metaInstance.minNodes > 1}
|
||||
<span
|
||||
class="inline-flex items-center px-1.5 py-0.5 text-[10px] font-mono tracking-wider uppercase border border-white/10 text-white/50"
|
||||
>
|
||||
{metaInstance.minNodes}+ nodes
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Node Assignments (when active) -->
|
||||
{#if backingNodeIds.length > 0}
|
||||
<div class="flex items-center gap-1.5 mt-1">
|
||||
<svg
|
||||
class="w-3 h-3 text-green-400/70 flex-shrink-0"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M22 12h-4l-3 9L9 3l-3 9H2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-white/60 text-xs font-mono truncate">
|
||||
{backingNodeIds.map((id) => getNodeName(id)).join(", ")}
|
||||
</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Pinned nodes constraint -->
|
||||
{#if metaInstance.nodeIds && metaInstance.nodeIds.length > 0}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3 h-3 text-white/40 flex-shrink-0"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="11" width="18" height="11" rx="2" ry="2" />
|
||||
<path d="M7 11V7a5 5 0 0 1 10 0v4" />
|
||||
</svg>
|
||||
<span class="text-white/40 text-[11px] font-mono">
|
||||
Pinned: {metaInstance.nodeIds
|
||||
.map((id) => getNodeName(id))
|
||||
.join(", ")}
|
||||
</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Error details -->
|
||||
{#if metaInstance.placementError}
|
||||
<div
|
||||
class="mt-1.5 p-2 bg-red-500/5 border border-red-500/15 rounded-sm"
|
||||
>
|
||||
<div class="text-red-400 text-[11px] font-mono leading-relaxed">
|
||||
{metaInstance.placementError}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Retry counter -->
|
||||
{#if metaInstance.consecutiveFailures > 0}
|
||||
<div class="flex items-center gap-1.5 mt-1">
|
||||
<svg
|
||||
class="w-3 h-3 text-yellow-500/60 flex-shrink-0"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<polyline points="23 4 23 10 17 10" />
|
||||
<path d="M20.49 15a9 9 0 1 1-2.12-9.36L23 10" />
|
||||
</svg>
|
||||
<span class="text-yellow-500/60 text-[11px] font-mono">
|
||||
{metaInstance.consecutiveFailures} consecutive
|
||||
failure{metaInstance.consecutiveFailures !== 1 ? "s" : ""}
|
||||
</span>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -11,4 +11,5 @@ export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as MetaInstanceCard } from "./MetaInstanceCard.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

@@ -72,8 +72,23 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  metaInstanceId?: string | null;
}

export interface MetaInstance {
  metaInstanceId: string;
  modelId: string;
  sharding: "Pipeline" | "Tensor";
  instanceMeta: "MlxRing" | "MlxJaccl";
  minNodes: number;
  nodeIds: string[] | null;
  placementError: string | null;
  consecutiveFailures: number;
  lastFailureError: string | null;
}

export type MetaInstanceStatus = "active" | "provisioning" | "error";

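For orientation, the meta-instance objects the dashboard receives are plain JSON with exactly these camelCase fields. A small Python sketch (values are made up for illustration) that also mirrors the status rule implemented in `getMetaInstanceStatus` further down:

```python
# Hypothetical payload matching the MetaInstance interface above (example values).
meta_instance = {
    "metaInstanceId": "mi-1234",
    "modelId": "mlx-community/Qwen3-8B-4bit",
    "sharding": "Pipeline",
    "instanceMeta": "MlxRing",
    "minNodes": 2,
    "nodeIds": None,            # not pinned to specific nodes
    "placementError": None,
    "consecutiveFailures": 0,
    "lastFailureError": None,
}

def status_of(mi: dict, has_backing_instance: bool) -> str:
    """Sketch of the dashboard's status derivation: active, then error, then provisioning."""
    if has_backing_instance:      # some running instance is bound to this meta-instance
        return "active"
    if mi["placementError"]:
        return "error"
    return "provisioning"

print(status_of(meta_instance, has_backing_instance=False))  # -> "provisioning"
```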
|
||||
// Granular node state types from the new state structure
|
||||
interface RawNodeIdentity {
|
||||
modelId?: string;
|
||||
@@ -223,6 +238,7 @@ interface RawStateResponse {
|
||||
MlxJacclInstance?: Instance;
|
||||
}
|
||||
>;
|
||||
metaInstances?: Record<string, MetaInstance>;
|
||||
runners?: Record<string, unknown>;
|
||||
downloads?: Record<string, unknown[]>;
|
||||
// New granular node state fields
|
||||
@@ -250,11 +266,6 @@ interface RawStateResponse {
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
// Disk usage per node
|
||||
nodeDisk?: Record<
|
||||
string,
|
||||
{ total: { inBytes: number }; available: { inBytes: number } }
|
||||
>;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -538,6 +549,7 @@ class AppStore {
|
||||
// Topology state
|
||||
topologyData = $state<TopologyData | null>(null);
|
||||
instances = $state<Record<string, unknown>>({});
|
||||
metaInstances = $state<Record<string, MetaInstance>>({});
|
||||
runners = $state<Record<string, unknown>>({});
|
||||
downloads = $state<Record<string, unknown[]>>({});
|
||||
nodeDisk = $state<
|
||||
@@ -1273,6 +1285,9 @@ class AppStore {
|
||||
this.instances = data.instances;
|
||||
this.refreshConversationModelFromInstances();
|
||||
}
|
||||
if (data.metaInstances) {
|
||||
this.metaInstances = data.metaInstances;
|
||||
}
|
||||
if (data.runners) {
|
||||
this.runners = data.runners;
|
||||
}
|
||||
@@ -1298,6 +1313,79 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
  async createMetaInstance(
    modelId: string,
    sharding: "Pipeline" | "Tensor" = "Pipeline",
    instanceMeta: "MlxRing" | "MlxJaccl" = "MlxRing",
    minNodes: number = 1,
    nodeIds: string[] | null = null,
  ) {
    try {
      const response = await fetch("/meta_instance", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          model_id: modelId,
          sharding,
          instance_meta: instanceMeta,
          min_nodes: minNodes,
          node_ids: nodeIds,
        }),
      });
      if (!response.ok) {
        console.error("Failed to create meta-instance:", response.status);
      }
      await this.fetchState();
    } catch (error) {
      console.error("Error creating meta-instance:", error);
    }
  }

  async deleteMetaInstance(metaInstanceId: string) {
    try {
      const response = await fetch(`/meta_instance/${metaInstanceId}`, {
        method: "DELETE",
        headers: { "Content-Type": "application/json" },
      });
      if (!response.ok) {
        console.error("Failed to delete meta-instance:", response.status);
      }
      await this.fetchState();
    } catch (error) {
      console.error("Error deleting meta-instance:", error);
    }
  }

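These two methods are thin wrappers over the master's HTTP API, so the same operations can be scripted directly. A hedged sketch using only the Python standard library (the base URL is an assumption; adjust it to wherever the exo API is listening):

```python
import json
import urllib.request

BASE = "http://localhost:52415"  # assumed API address of the exo master

def create_meta_instance(model_id: str, min_nodes: int = 1) -> int:
    """POST /meta_instance with the same snake_case body the dashboard sends."""
    body = json.dumps({
        "model_id": model_id,
        "sharding": "Pipeline",       # or "Tensor"
        "instance_meta": "MlxRing",   # or "MlxJaccl"
        "min_nodes": min_nodes,
        "node_ids": None,             # optionally pin to specific node ids
    }).encode()
    req = urllib.request.Request(
        f"{BASE}/meta_instance",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status

def delete_meta_instance(meta_instance_id: str) -> int:
    """DELETE /meta_instance/{id}."""
    req = urllib.request.Request(
        f"{BASE}/meta_instance/{meta_instance_id}", method="DELETE"
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status
```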
  getMetaInstanceStatus(
    metaInstance: MetaInstance,
  ): MetaInstanceStatus {
    // Check if any running instance is bound to this meta-instance
    for (const instanceWrapper of Object.values(this.instances)) {
      if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
      const keys = Object.keys(instanceWrapper as Record<string, unknown>);
      if (keys.length !== 1) continue;
      const inner = (instanceWrapper as Record<string, unknown>)[keys[0]];
      if (inner && typeof inner === "object" && (inner as Instance).metaInstanceId === metaInstance.metaInstanceId) {
        return "active";
      }
    }
    if (metaInstance.placementError) return "error";
    return "provisioning";
  }

  getMetaInstanceBackingNodes(metaInstance: MetaInstance): string[] {
    for (const instanceWrapper of Object.values(this.instances)) {
      if (!instanceWrapper || typeof instanceWrapper !== "object") continue;
      const keys = Object.keys(instanceWrapper as Record<string, unknown>);
      if (keys.length !== 1) continue;
      const inner = (instanceWrapper as Record<string, unknown>)[keys[0]] as Instance;
      if (inner?.metaInstanceId === metaInstance.metaInstanceId && inner?.shardAssignments?.nodeToRunner) {
        return Object.keys(inner.shardAssignments.nodeToRunner);
      }
    }
    return [];
  }

|
||||
async fetchPlacementPreviews(modelId: string, showLoading = true) {
|
||||
if (!modelId) return;
|
||||
|
||||
@@ -3159,6 +3247,7 @@ export const totalTokens = () => appStore.totalTokens;
|
||||
export const prefillProgress = () => appStore.prefillProgress;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const metaInstances = () => appStore.metaInstances;
|
||||
export const runners = () => appStore.runners;
|
||||
export const downloads = () => appStore.downloads;
|
||||
export const nodeDisk = () => appStore.nodeDisk;
|
||||
@@ -3247,6 +3336,21 @@ export const setChatSidebarVisible = (visible: boolean) =>
|
||||
appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
// Meta-instance actions
|
||||
export const createMetaInstance = (
|
||||
modelId: string,
|
||||
sharding?: "Pipeline" | "Tensor",
|
||||
instanceMeta?: "MlxRing" | "MlxJaccl",
|
||||
minNodes?: number,
|
||||
nodeIds?: string[] | null,
|
||||
) => appStore.createMetaInstance(modelId, sharding, instanceMeta, minNodes, nodeIds);
|
||||
export const deleteMetaInstance = (metaInstanceId: string) =>
|
||||
appStore.deleteMetaInstance(metaInstanceId);
|
||||
export const getMetaInstanceStatus = (metaInstance: MetaInstance) =>
|
||||
appStore.getMetaInstanceStatus(metaInstance);
|
||||
export const getMetaInstanceBackingNodes = (metaInstance: MetaInstance) =>
|
||||
appStore.getMetaInstanceBackingNodes(metaInstance);
|
||||
|
||||
// Node identities (for OS version mismatch detection)
|
||||
export const nodeIdentities = () => appStore.nodeIdentities;
|
||||
|
||||
|
||||
@@ -47,10 +47,14 @@
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
nodeIdentities,
|
||||
metaInstances,
|
||||
deleteMetaInstance,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview,
|
||||
type MetaInstance,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
import MetaInstanceCard from "$lib/components/MetaInstanceCard.svelte";
|
||||
import { fade, fly } from "svelte/transition";
|
||||
import { cubicInOut } from "svelte/easing";
|
||||
import { onMount } from "svelte";
|
||||
@@ -67,6 +71,8 @@
|
||||
const loadingPreviews = $derived(isLoadingPreviews());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const metaInstanceData = $derived(metaInstances());
|
||||
const metaInstanceCount = $derived(Object.keys(metaInstanceData).length);
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
@@ -858,8 +864,10 @@
|
||||
if (!progress || typeof progress !== "object") return null;
|
||||
|
||||
const prog = progress as Record<string, unknown>;
|
||||
const totalBytes = getBytes(prog.total);
|
||||
const downloadedBytes = getBytes(prog.downloaded);
|
||||
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
|
||||
const downloadedBytes = getBytes(
|
||||
prog.downloaded_bytes ?? prog.downloadedBytes,
|
||||
);
|
||||
const speed = (prog.speed as number) ?? 0;
|
||||
const completedFiles =
|
||||
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
|
||||
@@ -872,8 +880,8 @@
|
||||
for (const [fileName, fileData] of Object.entries(filesObj)) {
|
||||
if (!fileData || typeof fileData !== "object") continue;
|
||||
const fd = fileData as Record<string, unknown>;
|
||||
const fTotal = getBytes(fd.total);
|
||||
const fDownloaded = getBytes(fd.downloaded);
|
||||
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
|
||||
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
|
||||
files.push({
|
||||
name: fileName,
|
||||
totalBytes: fTotal,
|
||||
@@ -1262,6 +1270,7 @@
|
||||
if (typeof value === "number") return value;
|
||||
if (value && typeof value === "object") {
|
||||
const v = value as Record<string, unknown>;
|
||||
if (typeof v.in_bytes === "number") return v.in_bytes;
|
||||
if (typeof v.inBytes === "number") return v.inBytes;
|
||||
}
|
||||
return 0;
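`getBytes` now accepts either a raw number or a byte wrapper in both snake_case (`in_bytes`) and camelCase (`inBytes`) form, which matches `Memory` moving to `CamelCaseModel` later in this diff. The same normalization in Python, handy when scripting against `/state` (a sketch, not part of the change):

```python
from typing import Any

def get_bytes(value: Any) -> int:
    """Normalize a byte count that may arrive as a number or a Memory-like wrapper."""
    if isinstance(value, (int, float)):
        return int(value)
    if isinstance(value, dict):
        for key in ("in_bytes", "inBytes"):  # old snake_case or new camelCase key
            if isinstance(value.get(key), (int, float)):
                return int(value[key])
    return 0

assert get_bytes(1024) == 1024
assert get_bytes({"inBytes": 2048}) == 2048
assert get_bytes({"in_bytes": 512}) == 512
assert get_bytes(None) == 0
```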
|
||||
@@ -3053,6 +3062,39 @@
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Meta-Instances Panel -->
|
||||
{#if metaInstanceCount > 0}
|
||||
<div class="p-4 flex-shrink-0 border-t border-exo-yellow/10">
|
||||
<!-- Panel Header -->
|
||||
<div class="flex items-center gap-2 mb-4">
|
||||
<div
|
||||
class="w-2 h-2 border border-purple-400/60 rotate-45"
|
||||
></div>
|
||||
<h3
|
||||
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
|
||||
>
|
||||
Meta-Instances
|
||||
</h3>
|
||||
<div
|
||||
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
|
||||
></div>
|
||||
<span class="text-[10px] text-white/40 font-mono"
|
||||
>{metaInstanceCount}</span
|
||||
>
|
||||
</div>
|
||||
|
||||
<div class="space-y-3">
|
||||
{#each Object.entries(metaInstanceData) as [id, mi]}
|
||||
<MetaInstanceCard
|
||||
metaInstance={mi}
|
||||
onDelete={(metaInstanceId) =>
|
||||
deleteMetaInstance(metaInstanceId)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Models Panel - Scrollable -->
|
||||
<div class="p-4 flex-1 overflow-y-auto">
|
||||
<!-- Panel Header -->
|
||||
@@ -3875,6 +3917,34 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Meta-Instances Section (chat sidebar) -->
|
||||
{#if metaInstanceCount > 0}
|
||||
<div class="p-4 border-t border-exo-yellow/10">
|
||||
<div class="flex items-center gap-2 mb-4">
|
||||
<div
|
||||
class="w-2 h-2 border border-purple-400/60 rotate-45"
|
||||
></div>
|
||||
<h3
|
||||
class="text-xs text-purple-400 font-mono tracking-[0.2em] uppercase"
|
||||
>
|
||||
Meta-Instances
|
||||
</h3>
|
||||
<div
|
||||
class="flex-1 h-px bg-gradient-to-r from-purple-400/30 to-transparent"
|
||||
></div>
|
||||
</div>
|
||||
<div class="space-y-3">
|
||||
{#each Object.entries(metaInstanceData) as [id, mi]}
|
||||
<MetaInstanceCard
|
||||
metaInstance={mi}
|
||||
onDelete={(metaInstanceId) =>
|
||||
deleteMetaInstance(metaInstanceId)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</aside>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -74,6 +74,7 @@
|
||||
if (typeof value === "number") return value;
|
||||
if (value && typeof value === "object") {
|
||||
const v = value as Record<string, unknown>;
|
||||
if (typeof v.in_bytes === "number") return v.in_bytes;
|
||||
if (typeof v.inBytes === "number") return v.inBytes;
|
||||
}
|
||||
return 0;
|
||||
@@ -230,14 +231,23 @@
|
||||
undefined;
|
||||
let cell: CellStatus;
|
||||
if (tag === "DownloadCompleted") {
|
||||
const totalBytes = getBytes(payload.total);
|
||||
const totalBytes = getBytes(
|
||||
payload.total_bytes ?? payload.totalBytes,
|
||||
);
|
||||
cell = { kind: "completed", totalBytes, modelDirectory };
|
||||
} else if (tag === "DownloadOngoing") {
|
||||
const rawProgress =
|
||||
payload.download_progress ?? payload.downloadProgress ?? {};
|
||||
const prog = rawProgress as Record<string, unknown>;
|
||||
const totalBytes = getBytes(prog.total ?? payload.total);
|
||||
const downloadedBytes = getBytes(prog.downloaded);
|
||||
const totalBytes = getBytes(
|
||||
prog.total_bytes ??
|
||||
prog.totalBytes ??
|
||||
payload.total_bytes ??
|
||||
payload.totalBytes,
|
||||
);
|
||||
const downloadedBytes = getBytes(
|
||||
prog.downloaded_bytes ?? prog.downloadedBytes,
|
||||
);
|
||||
const speed = (prog.speed as number) ?? 0;
|
||||
const etaMs =
|
||||
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;
|
||||
|
||||
@@ -80,7 +80,7 @@ class DownloadCoordinator:
            completed = DownloadCompleted(
                shard_metadata=callback_shard,
                node_id=self.node_id,
                total=progress.total,
                total_bytes=progress.total_bytes,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
            completed = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
                total=initial_progress.total,
                total_bytes=initial_progress.total_bytes,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
            status: DownloadProgress = DownloadCompleted(
                node_id=self.node_id,
                shard_metadata=progress.shard,
                total=progress.total,
                total_bytes=progress.total_bytes,
                model_directory=self._model_dir(
                    progress.shard.model_card.model_id
                ),
            )
        elif progress.status in ["in_progress", "not_started"]:
            if progress.downloaded_this_session.in_bytes == 0:
            if progress.downloaded_bytes_this_session.in_bytes == 0:
                status = DownloadPending(
                    node_id=self.node_id,
                    shard_metadata=progress.shard,

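These hunks are the node-side half of the field rename: `DownloadCompleted.total` becomes `total_bytes`, and `downloaded_this_session` becomes `downloaded_bytes_this_session` on the progress object. A hedged sketch of the event after the rename, with simplified stand-in types rather than exo's real models:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Memory:
    in_bytes: int = 0              # stand-in for exo.shared.types.memory.Memory

@dataclass(frozen=True)
class DownloadCompleted:
    node_id: str
    shard_metadata: object
    total_bytes: Memory            # renamed from `total`
    model_directory: str

completed = DownloadCompleted(
    node_id="node-1",
    shard_metadata=None,                        # stand-in for the real shard metadata
    total_bytes=Memory(in_bytes=5 * 1024**3),   # e.g. a 5 GiB model
    model_directory="/models/example",          # illustrative path
)
print(completed.total_bytes.in_bytes)
```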
@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
|
||||
repo_file_download_progress: RepoFileDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
downloaded=repo_file_download_progress.downloaded,
|
||||
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total=repo_file_download_progress.total,
|
||||
downloaded_bytes=repo_file_download_progress.downloaded,
|
||||
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total_bytes=repo_file_download_progress.total,
|
||||
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
|
||||
total_files=1,
|
||||
speed=repo_file_download_progress.speed,
|
||||
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
repo_download_progress: RepoDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
total=repo_download_progress.total,
|
||||
downloaded=repo_download_progress.downloaded,
|
||||
downloaded_this_session=repo_download_progress.downloaded_this_session,
|
||||
total_bytes=repo_download_progress.total_bytes,
|
||||
downloaded_bytes=repo_download_progress.downloaded_bytes,
|
||||
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
|
||||
completed_files=repo_download_progress.completed_files,
|
||||
total_files=repo_download_progress.total_files,
|
||||
speed=repo_download_progress.overall_speed,
|
||||
@@ -578,20 +578,19 @@ def calculate_repo_progress(
    file_progress: dict[str, RepoFileDownloadProgress],
    all_start_time: float,
) -> RepoDownloadProgress:
    all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
    all_downloaded = sum(
        (p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
    all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
    all_downloaded_bytes = sum(
        (p.downloaded.in_bytes for p in file_progress.values()), 0
    )
    all_downloaded_this_session = sum(
        (p.downloaded_this_session for p in file_progress.values()),
        Memory.from_bytes(0),
    all_downloaded_bytes_this_session = sum(
        (p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
    )
    elapsed_time = time.time() - all_start_time
    all_speed = (
        all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
        all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
    )
    all_eta = (
        timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
        timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
        if all_speed > 0
        else timedelta(seconds=0)
    )
@@ -610,9 +609,11 @@ def calculate_repo_progress(
            [p for p in file_progress.values() if p.downloaded == p.total]
        ),
        total_files=len(file_progress),
        downloaded=all_downloaded,
        downloaded_this_session=all_downloaded_this_session,
        total=all_total,
        downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
        downloaded_bytes_this_session=Memory.from_bytes(
            all_downloaded_bytes_this_session
        ),
        total_bytes=Memory.from_bytes(all_total_bytes),
        overall_speed=all_speed,
        overall_eta=all_eta,
        status=status,

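The aggregation now sums raw byte counts and only wraps the totals back into `Memory` when building the result. The same speed/ETA arithmetic in isolation (a self-contained sketch with simplified per-file records):

```python
import time
from dataclasses import dataclass
from datetime import timedelta

@dataclass
class FileProgress:
    total: int                    # bytes
    downloaded: int               # bytes
    downloaded_this_session: int  # bytes

def aggregate(files: dict[str, FileProgress], all_start_time: float):
    all_total_bytes = sum(p.total for p in files.values())
    all_downloaded_bytes = sum(p.downloaded for p in files.values())
    all_this_session = sum(p.downloaded_this_session for p in files.values())

    elapsed = time.time() - all_start_time
    all_speed = all_this_session / elapsed if elapsed > 0 else 0
    all_eta = (
        timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
        if all_speed > 0
        else timedelta(seconds=0)
    )
    return all_total_bytes, all_downloaded_bytes, all_speed, all_eta

# Example: one finished file and one halfway done, session started 10 s ago.
files = {
    "a.safetensors": FileProgress(1_000, 1_000, 1_000),
    "b.safetensors": FileProgress(2_000, 1_000, 1_000),
}
print(aggregate(files, time.time() - 10))
```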
@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded=Memory.from_bytes(0),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(0),
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.ollama_api import (
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaDoneReason,
|
||||
OllamaGenerateRequest,
|
||||
OllamaGenerateResponse,
|
||||
OllamaMessage,
|
||||
OllamaToolCall,
|
||||
OllamaToolFunction,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _map_done_reason(
|
||||
finish_reason: str | None,
|
||||
) -> OllamaDoneReason | None:
|
||||
if finish_reason is None:
|
||||
return None
|
||||
if finish_reason == "stop":
|
||||
return "stop"
|
||||
if finish_reason == "length":
|
||||
return "length"
|
||||
if finish_reason in ("tool_calls", "function_call"):
|
||||
return "tool_call"
|
||||
if finish_reason == "error":
|
||||
return "error"
|
||||
return "stop"
|
||||
|
||||
|
||||
def _try_parse_json(value: str) -> dict[str, Any] | str:
|
||||
try:
|
||||
return json.loads(value) # type: ignore
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
|
||||
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
|
||||
tool_calls: list[OllamaToolCall] = []
|
||||
for index, tool in enumerate(chunk.tool_calls):
|
||||
# tool.arguments is always str; try to parse as JSON dict for Ollama format
|
||||
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
|
||||
tool_calls.append(
|
||||
OllamaToolCall(
|
||||
id=tool.id,
|
||||
type="function",
|
||||
function=OllamaToolFunction(
|
||||
name=tool.name, arguments=arguments, index=index
|
||||
),
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _get_usage(
|
||||
chunk: TokenChunk | ToolCallChunk,
|
||||
) -> tuple[int | None, int | None]:
|
||||
"""Extract (prompt_eval_count, eval_count) from a chunk."""
|
||||
if chunk.usage is not None:
|
||||
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
|
||||
if chunk.stats is not None:
|
||||
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
|
||||
return (None, None)
|
||||
|
||||
|
||||
def ollama_request_to_text_generation(
|
||||
request: OllamaChatRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
"""Convert Ollama chat request to exo's internal text generation format."""
|
||||
instructions: str | None = None
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
tool_message_index = 0
|
||||
|
||||
for msg in request.messages:
|
||||
content = msg.content or ""
|
||||
|
||||
if msg.role == "system":
|
||||
if instructions is None:
|
||||
instructions = content
|
||||
else:
|
||||
instructions = f"{instructions}\n{content}"
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant") and (
|
||||
msg.content is not None or msg.thinking is not None or msg.tool_calls
|
||||
):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
dumped: dict[str, Any] = {"role": msg.role, "content": content}
|
||||
if msg.thinking is not None:
|
||||
dumped["thinking"] = msg.thinking
|
||||
if msg.tool_calls is not None:
|
||||
tool_calls_list: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
function: dict[str, Any] = {
|
||||
"name": tc.function.name,
|
||||
"arguments": (
|
||||
json.dumps(tc.function.arguments)
|
||||
if isinstance(tc.function.arguments, dict)
|
||||
else tc.function.arguments
|
||||
),
|
||||
}
|
||||
if tc.function.index is not None:
|
||||
function["index"] = tc.function.index
|
||||
tool_call: dict[str, Any] = {"function": function}
|
||||
if tc.id is not None:
|
||||
tool_call["id"] = tc.id
|
||||
if tc.type is not None:
|
||||
tool_call["type"] = tc.type
|
||||
tool_calls_list.append(tool_call)
|
||||
dumped["tool_calls"] = tool_calls_list
|
||||
if msg.name is not None:
|
||||
dumped["name"] = msg.name
|
||||
if msg.role == "tool":
|
||||
tool_message_index += 1
|
||||
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
|
||||
dumped["tool_call_id"] = tool_call_id
|
||||
if msg.tool_name is not None:
|
||||
dumped["tool_name"] = msg.tool_name
|
||||
chat_template_messages.append(dumped)
|
||||
|
||||
options = request.options
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=options.num_predict if options else None,
|
||||
temperature=options.temperature if options else None,
|
||||
top_p=options.top_p if options else None,
|
||||
top_k=options.top_k if options else None,
|
||||
stop=options.stop if options else None,
|
||||
seed=options.seed if options else None,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
enable_thinking=request.think,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
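To make the mapping concrete, here is a sketch (made-up values, fields shown as plain dicts rather than exo's pydantic models) of an incoming Ollama chat request and the pieces `ollama_request_to_text_generation` derives from it:

```python
# Hypothetical Ollama-style chat request body.
request = {
    "model": "mlx-community/Qwen3-8B-4bit",     # example model id
    "messages": [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi!"},
    ],
    "options": {"temperature": 0.2, "num_predict": 64},
    "stream": True,
}

# Roughly the task parameters the converter builds from it:
task_params = {
    "model": "mlx-community/Qwen3-8B-4bit",
    "instructions": "You are terse.",               # system messages collapse into instructions
    "input": [{"role": "user", "content": "Hi!"}],  # non-system turns become input messages
    "max_output_tokens": 64,                        # options.num_predict
    "temperature": 0.2,
    "stream": True,
    "chat_template_messages": [                     # full history kept for the chat template
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi!"},
    ],
}
```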
|
||||
|
||||
|
||||
async def generate_ollama_chat_stream(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
|
||||
thinking_parts: list[str] = []
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
case ErrorChunk():
|
||||
error_response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant", content=chunk.error_message
|
||||
),
|
||||
done=True,
|
||||
done_reason="error",
|
||||
)
|
||||
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case ToolCallChunk():
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=_build_tool_calls(chunk),
|
||||
thinking="".join(thinking_parts) if thinking_parts else None,
|
||||
),
|
||||
done=True,
|
||||
done_reason="tool_call",
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case TokenChunk():
|
||||
done = chunk.finish_reason is not None
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant", content="", thinking=chunk.text
|
||||
),
|
||||
done=False,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
elif done:
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content=chunk.text,
|
||||
),
|
||||
done=True,
|
||||
done_reason=_map_done_reason(chunk.finish_reason),
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
else:
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(role="assistant", content=chunk.text),
|
||||
done=False,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
|
||||
if done:
|
||||
return
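The stream is newline-delimited JSON: intermediate frames carry `done: false`, thinking tokens arrive in `message.thinking`, and the final frame sets `done`, `done_reason`, and the token counts. Illustrative frames (token text and counts are made up) and how a client would consume them:

```python
import json

# Example NDJSON frames as a client might receive them from /api/chat.
sample_stream = (
    '{"model":"qwen3-8b","message":{"role":"assistant","content":"Hel"},"done":false}\n'
    '{"model":"qwen3-8b","message":{"role":"assistant","content":"lo"},"done":false}\n'
    '{"model":"qwen3-8b","message":{"role":"assistant","content":"!"},'
    '"done":true,"done_reason":"stop","prompt_eval_count":12,"eval_count":3}\n'
)

for line in sample_stream.splitlines():
    frame = json.loads(line)
    print(frame["message"]["content"], frame["done"], frame.get("done_reason"))
```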
|
||||
|
||||
|
||||
async def collect_ollama_chat_response(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Collect streaming chunks into a single non-streaming Ollama response.
|
||||
|
||||
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
|
||||
StreamingResponse cancellation handling.
|
||||
"""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[OllamaToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: str | None = None
|
||||
prompt_eval_count: int | None = None
|
||||
eval_count: int | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
case ErrorChunk():
|
||||
raise ValueError(chunk.error_message or "Internal server error")
|
||||
|
||||
case TokenChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
|
||||
case ToolCallChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
tool_calls.extend(_build_tool_calls(chunk))
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts) if thinking_parts else None
|
||||
assert model is not None
|
||||
|
||||
yield OllamaChatResponse(
|
||||
model=model,
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
thinking=combined_thinking,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
done=True,
|
||||
done_reason=_map_done_reason(finish_reason),
|
||||
prompt_eval_count=prompt_eval_count,
|
||||
eval_count=eval_count,
|
||||
).model_dump_json(exclude_none=True)
|
||||
return
|
||||
|
||||
|
||||
# ── /api/generate ──
|
||||
|
||||
|
||||
def ollama_generate_request_to_text_generation(
|
||||
request: OllamaGenerateRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
"""Convert Ollama generate request to exo's internal text generation format."""
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
if request.system:
|
||||
chat_template_messages.append({"role": "system", "content": request.system})
|
||||
chat_template_messages.append({"role": "user", "content": request.prompt})
|
||||
|
||||
options = request.options
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=[InputMessage(role="user", content=request.prompt)],
|
||||
instructions=request.system,
|
||||
max_output_tokens=options.num_predict if options else None,
|
||||
temperature=options.temperature if options else None,
|
||||
top_p=options.top_p if options else None,
|
||||
top_k=options.top_k if options else None,
|
||||
stop=options.stop if options else None,
|
||||
seed=options.seed if options else None,
|
||||
stream=request.stream,
|
||||
enable_thinking=request.think,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
async def generate_ollama_generate_stream(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
|
||||
thinking_parts: list[str] = []
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
case ErrorChunk():
|
||||
resp = OllamaGenerateResponse(
|
||||
model=str(chunk.model),
|
||||
response="",
|
||||
done=True,
|
||||
done_reason="error",
|
||||
)
|
||||
yield f"{resp.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case ToolCallChunk():
|
||||
# generate endpoint doesn't support tools; emit as done
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
resp = OllamaGenerateResponse(
|
||||
model=str(chunk.model),
|
||||
response="",
|
||||
done=True,
|
||||
done_reason="stop",
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{resp.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case TokenChunk():
|
||||
done = chunk.finish_reason is not None
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
resp = OllamaGenerateResponse(
|
||||
model=str(chunk.model),
|
||||
response="",
|
||||
thinking=chunk.text,
|
||||
done=False,
|
||||
)
|
||||
yield f"{resp.model_dump_json(exclude_none=True)}\n"
|
||||
elif done:
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
resp = OllamaGenerateResponse(
|
||||
model=str(chunk.model),
|
||||
response=chunk.text,
|
||||
done=True,
|
||||
done_reason=_map_done_reason(chunk.finish_reason),
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{resp.model_dump_json(exclude_none=True)}\n"
|
||||
else:
|
||||
resp = OllamaGenerateResponse(
|
||||
model=str(chunk.model),
|
||||
response=chunk.text,
|
||||
done=False,
|
||||
)
|
||||
yield f"{resp.model_dump_json(exclude_none=True)}\n"
|
||||
|
||||
if done:
|
||||
return
|
||||
|
||||
|
||||
async def collect_ollama_generate_response(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Collect chunks into a single non-streaming /api/generate response."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: str | None = None
|
||||
prompt_eval_count: int | None = None
|
||||
eval_count: int | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
case ErrorChunk():
|
||||
raise ValueError(chunk.error_message or "Internal server error")
|
||||
case TokenChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
case ToolCallChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
|
||||
assert model is not None
|
||||
yield OllamaGenerateResponse(
|
||||
model=model,
|
||||
response="".join(text_parts),
|
||||
thinking="".join(thinking_parts) if thinking_parts else None,
|
||||
done=True,
|
||||
done_reason=_map_done_reason(finish_reason),
|
||||
prompt_eval_count=prompt_eval_count,
|
||||
eval_count=eval_count,
|
||||
).model_dump_json(exclude_none=True)
|
||||
return
|
||||
@@ -32,14 +32,6 @@ from exo.master.adapters.claude import (
|
||||
collect_claude_response,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.ollama import (
|
||||
collect_ollama_chat_response,
|
||||
collect_ollama_generate_response,
|
||||
generate_ollama_chat_stream,
|
||||
generate_ollama_generate_stream,
|
||||
ollama_generate_request_to_text_generation,
|
||||
ollama_request_to_text_generation,
|
||||
)
|
||||
from exo.master.adapters.responses import (
|
||||
collect_responses_response,
|
||||
generate_responses_stream,
|
||||
@@ -149,19 +141,6 @@ from exo.shared.types.events import (
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.ollama_api import (
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaGenerateRequest,
|
||||
OllamaGenerateResponse,
|
||||
OllamaModelDetails,
|
||||
OllamaModelTag,
|
||||
OllamaPsModel,
|
||||
OllamaPsResponse,
|
||||
OllamaShowRequest,
|
||||
OllamaShowResponse,
|
||||
OllamaTagsResponse,
|
||||
)
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
@@ -321,20 +300,6 @@ class API:
|
||||
self.app.get("/images/{image_id}")(self.get_image)
|
||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||
# Ollama API — health checks (must be before static files mount)
|
||||
self.app.head("/")(self._ollama_root)
|
||||
self.app.head("/api/version")(self.ollama_version)
|
||||
# Ollama API
|
||||
self.app.post("/api/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/generate", response_model=None)(self.ollama_generate)
|
||||
self.app.get("/api/tags")(self.ollama_tags)
|
||||
self.app.get("/api/api/tags")(self.ollama_tags)
|
||||
self.app.get("/api/v1/tags")(self.ollama_tags)
|
||||
self.app.post("/api/show")(self.ollama_show)
|
||||
self.app.get("/api/ps")(self.ollama_ps)
|
||||
self.app.get("/api/version")(self.ollama_version)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(self.stream_events)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
@@ -1328,158 +1293,6 @@ class API:
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def _ollama_root(self) -> JSONResponse:
|
||||
"""Respond to HEAD / from Ollama CLI connectivity checks."""
|
||||
return JSONResponse(content="Ollama is running")
|
||||
|
||||
async def ollama_chat(
|
||||
self, request: Request
|
||||
) -> OllamaChatResponse | StreamingResponse:
|
||||
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
|
||||
body = await request.body()
|
||||
payload = OllamaChatRequest.model_validate_json(body)
|
||||
task_params = ollama_request_to_text_generation(payload)
|
||||
resolved_model = await self._resolve_and_validate_text_model(
|
||||
ModelId(task_params.model)
|
||||
)
|
||||
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_ollama_chat_stream(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_ollama_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def ollama_generate(
|
||||
self, request: Request
|
||||
) -> OllamaGenerateResponse | StreamingResponse:
|
||||
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
|
||||
body = await request.body()
|
||||
payload = OllamaGenerateRequest.model_validate_json(body)
|
||||
task_params = ollama_generate_request_to_text_generation(payload)
|
||||
resolved_model = await self._resolve_and_validate_text_model(
|
||||
ModelId(task_params.model)
|
||||
)
|
||||
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_ollama_generate_stream(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_ollama_generate_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def ollama_tags(self) -> OllamaTagsResponse:
|
||||
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
|
||||
|
||||
def none_if_empty(value: str) -> str | None:
|
||||
return value or None
|
||||
|
||||
downloaded_model_ids: set[str] = set()
|
||||
for node_downloads in self.state.downloads.values():
|
||||
for dl in node_downloads:
|
||||
if isinstance(dl, DownloadCompleted):
|
||||
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
|
||||
|
||||
cards = [
|
||||
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
|
||||
]
|
||||
|
||||
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
return OllamaTagsResponse(
|
||||
models=[
|
||||
OllamaModelTag(
|
||||
name=str(card.model_id),
|
||||
model=str(card.model_id),
|
||||
modified_at=now,
|
||||
size=card.storage_size.in_bytes,
|
||||
digest="sha256:000000000000",
|
||||
details=OllamaModelDetails(
|
||||
family=none_if_empty(card.family),
|
||||
quantization_level=none_if_empty(card.quantization),
|
||||
),
|
||||
)
|
||||
for card in cards
|
||||
]
|
||||
)
|
||||
|
||||
async def ollama_show(self, request: Request) -> OllamaShowResponse:
|
||||
"""Returns model information in Ollama show format."""
|
||||
body = await request.body()
|
||||
payload = OllamaShowRequest.model_validate_json(body)
|
||||
try:
|
||||
card = await ModelCard.load(ModelId(payload.name))
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model not found: {payload.name}"
|
||||
) from exc
|
||||
|
||||
return OllamaShowResponse(
|
||||
details=OllamaModelDetails(
|
||||
family=card.family or None,
|
||||
quantization_level=card.quantization or None,
|
||||
),
|
||||
)
|
||||
|
||||
async def ollama_ps(self) -> OllamaPsResponse:
|
||||
"""Returns list of running models (active instances)."""
|
||||
models: list[OllamaPsModel] = []
|
||||
seen: set[str] = set()
|
||||
for instance in self.state.instances.values():
|
||||
model_id = str(instance.shard_assignments.model_id)
|
||||
if model_id in seen:
|
||||
continue
|
||||
seen.add(model_id)
|
||||
models.append(
|
||||
OllamaPsModel(
|
||||
name=model_id,
|
||||
model=model_id,
|
||||
size=0,
|
||||
)
|
||||
)
|
||||
return OllamaPsResponse(models=models)
|
||||
|
||||
async def ollama_version(self) -> dict[str, str]:
|
||||
"""Returns version information for Ollama API compatibility."""
|
||||
return {"version": "exo v1.0"}
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -1509,7 +1322,7 @@ class API:
|
||||
name=card.model_id.short(),
|
||||
description="",
|
||||
tags=[],
|
||||
storage_size_megabytes=card.storage_size.in_mb,
|
||||
storage_size_megabytes=int(card.storage_size.in_mb),
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
is_custom=is_custom_card(card.model_id),
|
||||
|
||||
@@ -102,21 +102,22 @@ def _allocate_and_validate_layers(
    layer_allocations = allocate_layers_proportionally(
        total_layers=model_card.n_layers,
        memory_fractions=[
            node_memory[node_id].ram_available / total_memory for node_id in node_ids
            node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
            for node_id in node_ids
        ],
    )

    total_storage = model_card.storage_size
    total_storage_bytes = model_card.storage_size.in_bytes
    total_layers = model_card.n_layers
    for i, node_id in enumerate(node_ids):
        node_layers = layer_allocations[i]
        required_memory = (total_storage * node_layers) // total_layers
        available_memory = node_memory[node_id].ram_available
        required_memory = (total_storage_bytes * node_layers) // total_layers
        available_memory = node_memory[node_id].ram_available.in_bytes
        if required_memory > available_memory:
            raise ValueError(
                f"Node {i} ({node_id}) has insufficient memory: "
                f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
                f"but only has {available_memory.in_gb:.2f} GB available"
                f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
                f"but only has {available_memory / (1024**3):.2f} GB available"
            )

    return layer_allocations

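The proportional split itself comes from `allocate_layers_proportionally`, whose implementation is not shown in this diff. A simplified largest-remainder sketch of the idea, plus the byte-based per-node check from the hunk above (example sizes are made up):

```python
def allocate_layers_proportionally(total_layers: int, memory_fractions: list[float]) -> list[int]:
    """Simplified largest-remainder split; exo's real helper may differ."""
    raw = [f * total_layers for f in memory_fractions]
    layers = [int(r) for r in raw]
    # Hand the remaining layers to the nodes with the largest fractional parts.
    for i in sorted(range(len(raw)), key=lambda i: raw[i] - layers[i], reverse=True):
        if sum(layers) == total_layers:
            break
        layers[i] += 1
    return layers

# Example: 32 layers across nodes with 96 GiB and 32 GiB of available RAM.
GiB = 1024**3
available = [96 * GiB, 32 * GiB]
total = sum(available)
alloc = allocate_layers_proportionally(32, [a / total for a in available])
print(alloc)  # -> [24, 8]

# Per-node check mirroring the loop above: bytes needed for the assigned layers.
total_storage_bytes = 40 * GiB   # example model size
for node_layers, avail in zip(alloc, available):
    required = (total_storage_bytes * node_layers) // 32
    assert required <= avail, "insufficient memory on node"
```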
@@ -341,7 +342,6 @@ def _find_ip_prioritised(
    other_node_id: NodeId,
    cycle_digraph: Topology,
    node_network: Mapping[NodeId, NodeNetworkInfo],
    ring: bool,
) -> str | None:
    """Find an IP address between nodes with prioritization.

@@ -354,27 +354,13 @@ def _find_ip_prioritised(
    ip_to_type = {
        iface.ip_address: iface.interface_type for iface in other_network.interfaces
    }

    # Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
    # TODO: Profile and get actual connection speeds.
    if ring:
        priority = {
            "thunderbolt": 0,
            "maybe_ethernet": 1,
            "ethernet": 2,
            "wifi": 3,
            "unknown": 4,
        }

    # RDMA prefers ethernet coordinator
    else:
        priority = {
            "ethernet": 0,
            "wifi": 1,
            "unknown": 2,
            "maybe_ethernet": 3,
            "thunderbolt": 4,
        }
    priority = {
        "ethernet": 0,
        "wifi": 1,
        "unknown": 2,
        "maybe_ethernet": 3,
        "thunderbolt": 4,
    }
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

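After this change the selection no longer distinguishes ring from RDMA placements: every caller uses the single ethernet-first ranking and a `min` over candidate IPs. The core of that selection in isolation (interface names and addresses are made up):

```python
def pick_ip(ips: list[str], ip_to_type: dict[str, str]) -> str:
    # Ethernet is preferred; thunderbolt is ranked last; unknown types fall back to 2.
    priority = {
        "ethernet": 0,
        "wifi": 1,
        "unknown": 2,
        "maybe_ethernet": 3,
        "thunderbolt": 4,
    }
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

ips = ["169.254.10.2", "192.168.1.20", "10.0.0.5"]
ip_to_type = {
    "169.254.10.2": "thunderbolt",
    "192.168.1.20": "wifi",
    "10.0.0.5": "ethernet",
}
print(pick_ip(ips, ip_to_type))  # -> 10.0.0.5
```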
@@ -414,7 +400,7 @@ def get_mlx_ring_hosts_by_node(
|
||||
continue
|
||||
|
||||
connection_ip = _find_ip_prioritised(
|
||||
node_id, other_node_id, cycle_digraph, node_network, ring=True
|
||||
node_id, other_node_id, cycle_digraph, node_network
|
||||
)
|
||||
if connection_ip is None:
|
||||
raise ValueError(
|
||||
@@ -445,9 +431,7 @@ def get_mlx_jaccl_coordinators(
|
||||
if n == coordinator:
|
||||
return "0.0.0.0"
|
||||
|
||||
ip = _find_ip_prioritised(
|
||||
n, coordinator, cycle_digraph, node_network, ring=False
|
||||
)
|
||||
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
|
||||
if ip is not None:
|
||||
return ip
|
||||
|
||||
|
||||
@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
|
||||
):
|
||||
# arrange
|
||||
model_card.n_layers = total_layers
|
||||
model_card.storage_size = Memory.from_bytes(
|
||||
sum(available_memory)
|
||||
model_card.storage_size.in_bytes = sum(
|
||||
available_memory
|
||||
) # make it exactly fit across all nodes
|
||||
topology = Topology()
|
||||
|
||||
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
# arrange
|
||||
topology = Topology()
|
||||
model_card.n_layers = 12
|
||||
model_card.storage_size = Memory.from_bytes(1500)
|
||||
model_card.storage_size.in_bytes = 1500
|
||||
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
|
||||
event = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total=Memory(),
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
|
||||
new_state = apply_node_download_progress(
|
||||
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
|
||||
event1 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total=Memory(),
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
event2 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard2,
|
||||
total=Memory(),
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
state = State(downloads={NodeId("node-1"): [event1]})
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
from math import ceil
from typing import Self, overload
from typing import Self

from exo.utils.pydantic_ext import FrozenModel
from exo.utils.pydantic_ext import CamelCaseModel


class Memory(FrozenModel):
class Memory(CamelCaseModel):
in_bytes: int = 0

@classmethod
@@ -33,22 +33,12 @@ class Memory(FrozenModel):
return cls(in_bytes=round(val * 1024))

@property
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))

@in_mb.setter
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)

@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)

@in_float_mb.setter
def in_float_mb(self, val: float):
@in_mb.setter
def in_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))

@@ -67,85 +57,17 @@ class Memory(FrozenModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)

def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented
def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)

def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented
def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes

def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented
def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes

def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))
def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes

def __rmul__(self, other: int | float):
return self * other

@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented

def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented

def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented

def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented

def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented

def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented

def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented

def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"

def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"

return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
|
||||
def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes

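For orientation, here is a minimal usage sketch of the frozen Memory value type on the richer side of the hunks above. The constructors and operators come from the diff itself; the import path is an assumption.

# Hedged sketch: exercising the FrozenModel-based Memory shown above.
# The module path is assumed; from_bytes, in_gb and the operators are
# taken directly from the diffed class.
from exo.shared.types.memory import Memory  # assumed import path

weights = Memory.from_bytes(8 * 1024**3)      # 8 GiB of weights
kv_cache = Memory.from_bytes(512 * 1024**2)   # 512 MiB of KV cache
total = weights + kv_cache                    # __add__ returns a new Memory
print(total.in_gb)                            # ~8.5
print(total / weights)                        # Memory / Memory -> float ratio
print(total // 4)                             # Memory // int -> Memory
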
@@ -1,147 +0,0 @@
from __future__ import annotations

import time
from typing import Any, Literal

from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelId

# https://github.com/ollama/ollama/blob/main/docs/api.md

OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]


class OllamaToolFunction(BaseModel, frozen=True):
name: str
arguments: dict[str, Any] | str
index: int | None = None


class OllamaToolCall(BaseModel, frozen=True):
id: str | None = None
type: Literal["function"] | None = None
function: OllamaToolFunction


class OllamaMessage(BaseModel, frozen=True):
role: OllamaRole
content: str | None = None
thinking: str | None = None
tool_calls: list[OllamaToolCall] | None = None
name: str | None = None
tool_name: str | None = None
images: list[str] | None = None


class OllamaOptions(BaseModel, frozen=True):
num_predict: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None


class OllamaChatRequest(BaseModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
options: OllamaOptions | None = None
tools: list[dict[str, Any]] | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None


class OllamaGenerateRequest(BaseModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
stream: bool = True
options: OllamaOptions | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
raw: bool = False


class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
response: str
thinking: str | None = None
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None


class OllamaShowRequest(BaseModel, frozen=True):
name: str
verbose: bool | None = None


class OllamaChatResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
message: OllamaMessage
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None


class OllamaModelDetails(BaseModel, frozen=True, strict=True):
format: str | None = None
family: str | None = None
parameter_size: str | None = None
quantization_level: str | None = None


class OllamaModelTag(BaseModel, frozen=True, strict=True):
name: str
model: str | None = None
modified_at: str | None = None
size: int | None = None
digest: str | None = None
details: OllamaModelDetails | None = None


class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaModelTag]


class OllamaShowResponse(BaseModel, frozen=True, strict=True):
modelfile: str | None = None
parameters: str | None = None
template: str | None = None
details: OllamaModelDetails | None = None
model_info: dict[str, Any] | None = None


class OllamaPsModel(BaseModel, frozen=True, strict=True):
name: str
model: str
size: int
digest: str | None = None
details: OllamaModelDetails | None = None
expires_at: str | None = None
size_vram: int | None = None


class OllamaPsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaPsModel]
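The models in the file above mirror Ollama's /api/chat and /api/generate schemas. As a hedged sketch of what they do, this is how the chat request model would validate an Ollama-style payload; the field names come from the classes above, while the payload values and the acceptance of this particular model string by ModelId are assumptions.

# Hedged sketch: validating an Ollama-style /api/chat body with the
# request models above. Values are illustrative only.
payload = {
    "model": "llama3.2:1b",  # assumed to be a valid ModelId
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    "options": {"temperature": 0.2, "num_predict": 64},
}
req = OllamaChatRequest.model_validate(payload)  # pydantic v2 API
assert req.messages[0].role == "user"
assert req.options is not None and req.options.temperature == 0.2
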
@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


class DownloadProgressData(CamelCaseModel):
total: Memory
downloaded: Memory
downloaded_this_session: Memory
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory

completed_files: int
total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):


class DownloadCompleted(BaseDownloadProgress):
total: Memory
total_bytes: Memory


class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
shard: ShardMetadata
completed_files: int
total_files: int
downloaded: Memory
downloaded_this_session: Memory
total: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]

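The two sides of these hunks differ only in field naming (total/downloaded versus total_bytes/downloaded_bytes); the shape of the progress data is the same. A small hedged helper showing how those fields compose into a completion ratio — written against plain byte counts so it is self-contained, and not part of the repository.

# Hedged sketch: completion ratio from the progress fields above.
def completion_ratio(downloaded_bytes: int, total_bytes: int) -> float:
    # Guard against a zero-sized repo so callers never divide by zero.
    if total_bytes <= 0:
        return 0.0
    return min(1.0, downloaded_bytes / total_bytes)

completion_ratio(750 * 1024**2, 1024**3)  # ~0.73 of a 1 GiB shard
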
@@ -192,13 +192,7 @@ class MpReceiver[T]:
try:
return self.receive_nowait()
except WouldBlock:
try:
item = self._state.buffer.get()
except (TypeError, OSError):
# Queue pipe can get closed while we are blocked on get().
# The underlying connection._handle becomes None, causing
# TypeError in read(handle, remaining).
raise ClosedResourceError from None
item = self._state.buffer.get()
if isinstance(item, _MpEndOfStream):
self.close()
raise EndOfStream from None

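The richer side of this hunk maps a queue torn down mid-get() onto ClosedResourceError, so consumers can treat a closed channel the same way as end-of-stream. A hedged consumer sketch; the synchronous receive() signature is an assumption inferred from the hunk.

# Hedged sketch: draining an MpReceiver until the stream ends or the
# underlying queue is torn down. Exception names follow the hunk above.
from anyio import ClosedResourceError, EndOfStream

def drain(receiver) -> list:
    items = []
    while True:
        try:
            items.append(receiver.receive())
        except (EndOfStream, ClosedResourceError):
            return items
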
@@ -108,7 +108,7 @@ async def check_reachable(
await send.send((target_ip, expected_node_id))

async with (
httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():

@@ -166,7 +166,7 @@ def generate_image(
else 0.0
)

peak_memory = Memory.from_bytes(mx.get_peak_memory())
peak_memory_gb = mx.get_peak_memory() / (1024**3)

stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=peak_memory,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)

buffer = io.BytesIO()

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:

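Both sides of this hunk pick an LRU eviction threshold from total system memory; only how total_gb is computed changes. A hedged sketch of the tiered selection — the 0.85 value for the 128 GiB tier comes from the hunk, while the lower-tier return values are assumed placeholders since the hunk cuts off.

# Hedged sketch of the tiered eviction threshold; only the >=128 GiB
# branch's 0.85 is confirmed by the hunk, the other values are assumed.
import psutil

def default_memory_threshold() -> float:
    total_gb = psutil.virtual_memory().total / (1024**3)
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
        return 0.80  # assumed
    return 0.70      # assumed fallback for smaller machines
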
@@ -232,11 +232,11 @@ def shard_and_load(

# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size = get_weights_size(shard_metadata)
timeout_seconds = base_timeout + model_size.in_gb
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
f"(model size: {model_size.in_gb:.1f}GB)"
)

match shard_metadata:

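The timeout heuristic in this hunk is a flat base (overridable via EXO_MODEL_LOAD_TIMEOUT) plus roughly one second per GiB of weights. A hedged standalone sketch of the same arithmetic, detached from the runner code:

# Hedged sketch of the load-timeout heuristic from the hunk above.
import os

def load_timeout_seconds(model_size_bytes: int) -> float:
    base = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
    return base + model_size_bytes / (1024**3)  # +1 s per GiB of weights

load_timeout_seconds(70 * 1024**3)  # 70 GiB model -> 370.0 s (env var unset)
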
@@ -642,17 +642,18 @@ def set_wired_limit_for_model(model_size: Memory):
if not mx.metal.is_available():
return

max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
model_bytes = model_size.in_bytes
max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
logger.warning(
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
f"Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size.in_bytes)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")

@@ -4,7 +4,7 @@ import resource
import time
from collections.abc import Generator
from functools import cache
from typing import TYPE_CHECKING, Literal
from typing import Literal

import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model

@@ -588,13 +588,6 @@ def main(
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
if not TYPE_CHECKING:
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc

gc.collect()

event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
@@ -619,8 +612,12 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)

if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc

gc.collect()
break

@@ -100,8 +100,8 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
@@ -180,7 +180,6 @@ class RunnerSupervisor:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
self._event_sender.close()

def __del__(self) -> None:
if self.runner_process.is_alive():
@@ -209,15 +208,10 @@ class RunnerSupervisor:

logger.opt(exception=e).error(f"Runner terminated ({cause})")

try:
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
except (ClosedResourceError, BrokenResourceError):
logger.warning(
"Event sender already closed, unable to report runner failure"
)
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
self.shutdown()

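One side of the first supervisor hunk above wraps the cancel-channel send in contextlib.suppress(ClosedResourceError) so a channel that was already closed cannot break shutdown. A hedged sketch of that pattern in isolation; the sender's synchronous send()/close() interface is an assumption.

# Hedged sketch: best-effort send into a channel that may already be
# closed, so teardown never raises. Mirrors the contextlib.suppress use
# in the hunk above.
import contextlib
from anyio import ClosedResourceError

def cancel_quietly(cancel_sender, task_id) -> None:
    with contextlib.suppress(ClosedResourceError):
        cancel_sender.send(task_id)
    cancel_sender.close()
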
@@ -90,10 +90,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}

@@ -134,7 +138,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
@@ -181,7 +187,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],  # NODE_B has no downloads completed yet
}
@@ -199,10 +207,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],  # NODE_B has no downloads completed yet
}
