mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-26 11:16:14 -05:00
Compare commits
8 Commits
ciaran/tra
...
leo/add-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0955966b2a | ||
|
|
20ea13f047 | ||
|
|
4ff1578140 | ||
|
|
a5873bc1fd | ||
|
|
dc1ce2a2cf | ||
|
|
ff57b00dc6 | ||
|
|
d3222c498a | ||
|
|
2f719d62a7 |
@@ -249,7 +249,8 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
|
||||
class CacheList(_BaseCache):
|
||||
def __init__(self, *caches) -> None: ...
|
||||
caches: tuple[_BaseCache, ...]
|
||||
def __init__(self, *caches: _BaseCache) -> None: ...
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
|
||||
@@ -170,5 +170,30 @@
|
||||
{/if}
|
||||
Downloads
|
||||
</a>
|
||||
<a
|
||||
href="/#/settings"
|
||||
class="text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
title="Settings"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
|
||||
/>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
|
||||
/>
|
||||
</svg>
|
||||
Settings
|
||||
</a>
|
||||
</nav>
|
||||
</header>
|
||||
|
||||
@@ -3158,23 +3158,6 @@ class AppStore {
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete traces by task IDs
|
||||
*/
|
||||
async deleteTraces(
|
||||
taskIds: string[],
|
||||
): Promise<{ deleted: string[]; notFound: string[] }> {
|
||||
const response = await fetch("/v1/traces/delete", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ taskIds }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to delete traces: ${response.status}`);
|
||||
}
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
@@ -3318,5 +3301,3 @@ export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
|
||||
export const deleteTraces = (taskIds: string[]) =>
|
||||
appStore.deleteTraces(taskIds);
|
||||
|
||||
87
dashboard/src/lib/stores/settings.svelte.ts
Normal file
87
dashboard/src/lib/stores/settings.svelte.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
/**
|
||||
* SettingsStore - Manages exo runtime settings via the /settings API.
|
||||
*/
|
||||
|
||||
/** Memory / safety related runtime settings (snake_case keys match the server payload). */
export interface MemorySettings {
  // Stop generation when memory is critically low (see settings page UI).
  oom_prevention: boolean;
  // Fraction of RAM above which KV cache eviction triggers (0..1).
  memory_threshold: number;
  // Free-memory reserve in GB.
  memory_floor_gb: number;
}

/** Generation / performance related runtime settings. */
export interface GenerationSettings {
  // Token chunk size during prompt processing.
  prefill_step_size: number;
  // Maximum generation length per response.
  max_tokens: number;
  // KV cache quantization bits; null means full precision.
  kv_cache_bits: 4 | 8 | null;
}

/** Full settings document exchanged with the /settings endpoint. */
export interface ExoSettings {
  memory: MemorySettings;
  generation: GenerationSettings;
}
|
||||
|
||||
function defaultSettings(): ExoSettings {
|
||||
return {
|
||||
memory: {
|
||||
oom_prevention: false,
|
||||
memory_threshold: 0.8,
|
||||
memory_floor_gb: 5.0,
|
||||
},
|
||||
generation: {
|
||||
prefill_step_size: 4096,
|
||||
max_tokens: 32168,
|
||||
kv_cache_bits: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
class SettingsStore {
|
||||
settings = $state<ExoSettings>(defaultSettings());
|
||||
loading = $state(false);
|
||||
error = $state<string | null>(null);
|
||||
|
||||
async load(): Promise<void> {
|
||||
this.loading = true;
|
||||
this.error = null;
|
||||
try {
|
||||
const response = await fetch("/settings");
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch settings: ${response.status}`);
|
||||
}
|
||||
this.settings = (await response.json()) as ExoSettings;
|
||||
} catch (err) {
|
||||
console.error("Failed to load settings:", err);
|
||||
this.error = err instanceof Error ? err.message : "Unknown error";
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
async save(updated: ExoSettings): Promise<boolean> {
|
||||
this.loading = true;
|
||||
this.error = null;
|
||||
try {
|
||||
const response = await fetch("/settings", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(updated),
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to save settings: ${response.status}`);
|
||||
}
|
||||
this.settings = (await response.json()) as ExoSettings;
|
||||
return true;
|
||||
} catch (err) {
|
||||
console.error("Failed to save settings:", err);
|
||||
this.error = err instanceof Error ? err.message : "Unknown error";
|
||||
return false;
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
resetToDefaults(): ExoSettings {
|
||||
return defaultSettings();
|
||||
}
|
||||
}
|
||||
|
||||
export const settingsStore = new SettingsStore();
|
||||
193
dashboard/src/routes/settings/+page.svelte
Normal file
193
dashboard/src/routes/settings/+page.svelte
Normal file
@@ -0,0 +1,193 @@
|
||||
<script lang="ts">
  import { onMount } from "svelte";
  import { fade } from "svelte/transition";
  import HeaderNav from "$lib/components/HeaderNav.svelte";
  import { settingsStore, type ExoSettings } from "$lib/stores/settings.svelte";
  import { addToast } from "$lib/stores/toast.svelte";

  // Local editable copy of the settings; null until the initial load finishes.
  let draft = $state<ExoSettings | null>(null);
  const loading = $derived(settingsStore.loading);

  onMount(async () => {
    await settingsStore.load();
    // Deep-copy so edits don't mutate the store until the user saves.
    draft = structuredClone(settingsStore.settings);
  });

  // Persist the draft via the store and toast the outcome.
  async function handleSave() {
    if (!draft) return;
    const ok = await settingsStore.save(draft);
    if (ok) {
      addToast({ type: "success", message: "Settings saved" });
    } else {
      addToast({ type: "error", message: settingsStore.error ?? "Failed to save settings" });
    }
  }

  // Reset the draft locally to defaults; nothing is persisted until Save.
  function handleReset() {
    draft = settingsStore.resetToDefaults();
  }

  const KV_OPTIONS: { label: string; value: 4 | 8 | null }[] = [
    { label: "None (full precision)", value: null },
    { label: "4-bit", value: 4 },
    { label: "8-bit", value: 8 },
  ];
</script>

<HeaderNav showHome={true} />

{#if draft}
  <div class="min-h-screen bg-background text-foreground" in:fade={{ duration: 200 }}>
    <div class="max-w-2xl mx-auto px-6 py-8">
      <h1 class="text-2xl font-bold text-exo-yellow tracking-wider uppercase mb-8">Settings</h1>

      <!-- Memory / Safety -->
      <section class="mb-10">
        <h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Memory / Safety</h2>
        <div class="space-y-5">
          <!-- OOM Prevention Toggle -->
          <div class="flex items-center justify-between">
            <div>
              <div class="text-sm text-white/90">OOM Prevention</div>
              <div class="text-xs text-white/40 mt-0.5">Stop generation when memory is low</div>
            </div>
            <button
              onclick={() => { if (draft) draft.memory.oom_prevention = !draft.memory.oom_prevention; }}
              class="relative w-11 h-6 rounded-full transition-colors duration-200 cursor-pointer {draft.memory.oom_prevention ? 'bg-exo-yellow' : 'bg-exo-medium-gray'}"
              role="switch"
              aria-checked={draft.memory.oom_prevention}
            >
              <span
                class="absolute top-0.5 left-0.5 w-5 h-5 rounded-full bg-white shadow transition-transform duration-200 {draft.memory.oom_prevention ? 'translate-x-5' : 'translate-x-0'}"
              ></span>
            </button>
          </div>

          <!-- Memory Threshold Slider -->
          <div>
            <div class="flex items-center justify-between mb-1.5">
              <div>
                <div class="text-sm text-white/90">Memory Threshold</div>
                <div class="text-xs text-white/40 mt-0.5">KV cache eviction triggers above this level</div>
              </div>
              <span class="text-sm font-mono text-exo-yellow">{(draft.memory.memory_threshold * 100).toFixed(0)}%</span>
            </div>
            <input
              type="range"
              min="0.5"
              max="0.99"
              step="0.01"
              bind:value={draft.memory.memory_threshold}
              class="w-full h-1.5 rounded-full appearance-none cursor-pointer bg-exo-medium-gray accent-exo-yellow"
            />
          </div>

          <!-- Memory Floor -->
          <div>
            <div class="flex items-center justify-between mb-1.5">
              <div>
                <div class="text-sm text-white/90">Memory Floor</div>
                <div class="text-xs text-white/40 mt-0.5">Minimum free memory to reserve (GB)</div>
              </div>
              <span class="text-sm font-mono text-exo-yellow">{draft.memory.memory_floor_gb.toFixed(1)} GB</span>
            </div>
            <input
              type="number"
              min="0"
              max="64"
              step="0.5"
              bind:value={draft.memory.memory_floor_gb}
              class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
            />
          </div>
        </div>
      </section>

      <!-- Generation / Performance -->
      <section class="mb-10">
        <h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Generation / Performance</h2>
        <div class="space-y-5">
          <!-- Prefill Step Size -->
          <div>
            <div class="flex items-center justify-between mb-1.5">
              <div>
                <div class="text-sm text-white/90">Prefill Step Size</div>
                <div class="text-xs text-white/40 mt-0.5">Token chunk size during prompt processing</div>
              </div>
              <span class="text-sm font-mono text-exo-yellow">{draft.generation.prefill_step_size.toLocaleString()}</span>
            </div>
            <input
              type="number"
              min="128"
              max="32768"
              step="128"
              bind:value={draft.generation.prefill_step_size}
              class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
            />
          </div>

          <!-- Max Tokens -->
          <div>
            <div class="flex items-center justify-between mb-1.5">
              <div>
                <div class="text-sm text-white/90">Max Tokens</div>
                <div class="text-xs text-white/40 mt-0.5">Maximum generation length per response</div>
              </div>
              <span class="text-sm font-mono text-exo-yellow">{draft.generation.max_tokens.toLocaleString()}</span>
            </div>
            <input
              type="number"
              min="1"
              max="131072"
              step="1024"
              bind:value={draft.generation.max_tokens}
              class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
            />
          </div>

          <!-- KV Cache Bits -->
          <div>
            <div class="mb-1.5">
              <div class="text-sm text-white/90">KV Cache Quantization</div>
              <div class="text-xs text-white/40 mt-0.5">Lower bits save memory at slight quality cost</div>
            </div>
            <select
              bind:value={draft.generation.kv_cache_bits}
              class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50 cursor-pointer"
            >
              {#each KV_OPTIONS as opt}
                <option value={opt.value}>{opt.label}</option>
              {/each}
            </select>
          </div>
        </div>
      </section>

      <!-- Action Buttons -->
      <div class="flex items-center gap-3">
        <button
          onclick={handleSave}
          disabled={loading}
          class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
            bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker
            disabled:opacity-50 disabled:cursor-not-allowed"
        >
          {loading ? "Saving..." : "Save"}
        </button>
        <button
          onclick={handleReset}
          disabled={loading}
          class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
            border border-exo-light-gray/30 text-white/70 hover:border-exo-yellow/50 hover:text-exo-yellow
            disabled:opacity-50 disabled:cursor-not-allowed"
        >
          Reset to Defaults
        </button>
      </div>
    </div>
  </div>
{:else}
  <!-- Shown until the initial settings load populates `draft`. -->
  <div class="min-h-screen bg-background flex items-center justify-center">
    <div class="text-white/40 text-sm">Loading settings...</div>
  </div>
{/if}
|
||||
@@ -3,7 +3,6 @@
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
deleteTraces,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
@@ -11,51 +10,6 @@
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
let selectedIds = $state<Set<string>>(new Set());
|
||||
let deleting = $state(false);
|
||||
|
||||
let allSelected = $derived(
|
||||
traces.length > 0 && selectedIds.size === traces.length,
|
||||
);
|
||||
|
||||
function toggleSelect(taskId: string) {
|
||||
const next = new Set(selectedIds);
|
||||
if (next.has(taskId)) {
|
||||
next.delete(taskId);
|
||||
} else {
|
||||
next.add(taskId);
|
||||
}
|
||||
selectedIds = next;
|
||||
}
|
||||
|
||||
function toggleSelectAll() {
|
||||
if (allSelected) {
|
||||
selectedIds = new Set();
|
||||
} else {
|
||||
selectedIds = new Set(traces.map((t) => t.taskId));
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete() {
|
||||
if (selectedIds.size === 0) return;
|
||||
const count = selectedIds.size;
|
||||
if (
|
||||
!confirm(
|
||||
`Delete ${count} trace${count === 1 ? "" : "s"}? This cannot be undone.`,
|
||||
)
|
||||
)
|
||||
return;
|
||||
deleting = true;
|
||||
try {
|
||||
await deleteTraces([...selectedIds]);
|
||||
selectedIds = new Set();
|
||||
await refresh();
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to delete traces";
|
||||
} finally {
|
||||
deleting = false;
|
||||
}
|
||||
}
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
@@ -155,16 +109,6 @@
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
{#if selectedIds.size > 0}
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-red-400 hover:text-red-300 transition-colors uppercase border border-red-500/40 px-2 py-1 rounded"
|
||||
onclick={handleDelete}
|
||||
disabled={deleting}
|
||||
>
|
||||
{deleting ? "Deleting..." : `Delete (${selectedIds.size})`}
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
@@ -199,41 +143,14 @@
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
<div class="flex items-center gap-2 px-1">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono uppercase transition-colors {allSelected
|
||||
? 'text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
onclick={toggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Deselect all" : "Select all"}
|
||||
</button>
|
||||
</div>
|
||||
{#each traces as trace}
|
||||
{@const isSelected = selectedIds.has(trace.taskId)}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="w-full text-left rounded border-l-2 border-r border-t border-b transition-all p-4 flex items-center justify-between gap-4 cursor-pointer {isSelected
|
||||
? 'bg-exo-yellow/10 border-l-exo-yellow border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30'
|
||||
: 'bg-exo-black/30 border-l-transparent border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30 hover:bg-white/[0.03]'}"
|
||||
onclick={() => toggleSelect(trace.taskId)}
|
||||
onkeydown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
toggleSelect(trace.taskId);
|
||||
}
|
||||
}}
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono transition-colors truncate block {isSelected
|
||||
? 'text-exo-yellow'
|
||||
: 'text-white hover:text-exo-yellow'}"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
@@ -243,11 +160,7 @@
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<div
|
||||
class="flex items-center gap-2 shrink-0"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
|
||||
@@ -81,8 +81,6 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteTracesRequest,
|
||||
DeleteTracesResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -168,6 +166,13 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.settings import (
|
||||
ExoSettings,
|
||||
load_settings,
|
||||
)
|
||||
from exo.shared.types.settings import (
|
||||
save_settings as save_settings_to_file,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -346,12 +351,13 @@ class API:
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.post("/v1/traces/delete")(self.delete_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
self.app.get("/onboarding")(self.get_onboarding)
|
||||
self.app.post("/onboarding")(self.complete_onboarding)
|
||||
self.app.get("/settings")(self.get_settings)
|
||||
self.app.post("/settings")(self.save_settings)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -1722,10 +1728,7 @@ class API:
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
trace_path = EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
if not trace_path.resolve().is_relative_to(EXO_TRACING_CACHE_DIR.resolve()):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task ID: {task_id}")
|
||||
return trace_path
|
||||
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
|
||||
async def list_traces(self) -> TraceListResponse:
|
||||
traces: list[TraceListItem] = []
|
||||
@@ -1824,18 +1827,6 @@ class API:
|
||||
filename=f"trace_{task_id}.json",
|
||||
)
|
||||
|
||||
async def delete_traces(self, request: DeleteTracesRequest) -> DeleteTracesResponse:
    """Delete stored trace files for the given task ids.

    Returns the ids whose trace file was removed and the ids that had no
    trace file on disk.
    """
    deleted: list[str] = []
    not_found: list[str] = []
    for task_id in request.task_ids:
        trace_path = self._get_trace_path(task_id)
        try:
            # Unlink directly instead of exists()+unlink() to avoid a TOCTOU
            # race with concurrent deletions of the same trace.
            trace_path.unlink()
            deleted.append(task_id)
        except FileNotFoundError:
            not_found.append(task_id)
    return DeleteTracesResponse(deleted=deleted, not_found=not_found)
|
||||
|
||||
async def get_onboarding(self) -> JSONResponse:
    """Report whether onboarding has completed (marker file exists on disk)."""
    completed = ONBOARDING_COMPLETE_FILE.exists()
    return JSONResponse({"completed": completed})
||||
|
||||
@@ -1843,3 +1834,13 @@ class API:
|
||||
ONBOARDING_COMPLETE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
ONBOARDING_COMPLETE_FILE.write_text("true")
|
||||
return JSONResponse({"completed": True})
|
||||
|
||||
async def get_settings(self) -> JSONResponse:
    """Return the current runtime settings as a JSON document."""
    current = load_settings()
    return JSONResponse(current.model_dump())
|
||||
|
||||
async def save_settings(self, request: Request) -> JSONResponse:
    """Validate a JSON settings payload, persist it, and echo the saved values.

    Malformed payloads are reported as HTTP 422 instead of surfacing as an
    unhandled pydantic ValidationError (which would become a 500).
    """
    # Local import keeps the module's import block untouched.
    from pydantic import ValidationError

    body = cast(object, await request.json())
    try:
        settings = ExoSettings.model_validate(body)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    save_settings_to_file(settings)
    return JSONResponse(settings.model_dump())
|
||||
|
||||
@@ -437,12 +437,3 @@ class TraceListItem(CamelCaseModel):
|
||||
|
||||
class TraceListResponse(CamelCaseModel):
|
||||
traces: list[TraceListItem]
|
||||
|
||||
|
||||
class DeleteTracesRequest(CamelCaseModel):
|
||||
task_ids: list[str]
|
||||
|
||||
|
||||
class DeleteTracesResponse(CamelCaseModel):
|
||||
deleted: list[str]
|
||||
not_found: list[str]
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import ctypes
|
||||
import sys
|
||||
from math import ceil
|
||||
from typing import Self, overload
|
||||
|
||||
import psutil
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
|
||||
|
||||
@@ -149,3 +154,67 @@ class Memory(FrozenModel):
|
||||
unit = "B"
|
||||
|
||||
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
|
||||
|
||||
|
||||
def _load_memory_settings() -> tuple[float, "Memory"]:
    """Return (memory_threshold, memory_floor) from the runtime settings.

    The settings module is imported lazily because it depends on this module,
    which would otherwise create a circular import.
    """
    from exo.shared.types.settings import load_settings

    memory_cfg = load_settings().memory
    return memory_cfg.memory_threshold, Memory.from_gb(memory_cfg.memory_floor_gb)
|
||||
|
||||
# Lazily-loaded handle to libSystem; populated on first sysctl call.
_libc: ctypes.CDLL | None = None


def _macos_memorystatus_level() -> int:
    """Read the macOS kern.memorystatus_level sysctl (percent of memory available).

    Loads libSystem once and caches the handle in the module-level ``_libc``.

    Raises:
        OSError: if the sysctlbyname call returns a non-zero status.
    """
    global _libc  # noqa: PLW0603
    if _libc is None:
        _libc = ctypes.CDLL("/usr/lib/libSystem.B.dylib")
    level = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(ctypes.c_int))
    # sysctlbyname(name, oldp, oldlenp, newp, newlen) — read-only query.
    ret: int = _libc.sysctlbyname(  # pyright: ignore[reportAny]
        b"kern.memorystatus_level",
        ctypes.byref(level),
        ctypes.byref(size),
        None,
        ctypes.c_size_t(0),
    )
    if ret != 0:
        raise OSError("sysctlbyname kern.memorystatus_level failed")
    return level.value
|
||||
|
||||
|
||||
def _get_macos_memory_pressure() -> float:
    """Memory pressure in [0, 1] on macOS; falls back to psutil on failure."""
    try:
        level = _macos_memorystatus_level()
    except (OSError, FileNotFoundError):
        logger.warning("Using fallback memory pressure")
        return _fallback_memory_pressure()
    # memorystatus_level reports percent *available*; invert to get pressure.
    return 1.0 - level / 100.0
|
||||
|
||||
|
||||
def _fallback_memory_pressure() -> float:
    """Approximate pressure as the fraction of RAM not currently available."""
    stats = psutil.virtual_memory()
    return 1.0 - stats.available / stats.total
|
||||
|
||||
|
||||
def get_memory_pressure() -> float:
    """Current system memory pressure in [0, 1], from the best platform source."""
    if sys.platform != "darwin":
        return _fallback_memory_pressure()
    return _get_macos_memory_pressure()
|
||||
|
||||
|
||||
def get_memory_limit() -> Memory:
    """Usable memory budget: total RAM minus a safety reserve.

    NOTE(review): the reserve is ``min(total * (1 - threshold), floor)``, so the
    configured floor acts as an *upper bound* on the reserve even though the UI
    describes it as "minimum free memory to reserve" — confirm intended
    semantics (max vs min).
    """
    threshold, floor = _load_memory_settings()
    total_bytes = psutil.virtual_memory().total
    reserve = min(int(total_bytes * (1 - threshold)), floor.in_bytes)
    return Memory.from_bytes(total_bytes - reserve)
|
||||
|
||||
|
||||
def get_memory_available_locally() -> Memory:
    """Memory budget remaining after subtracting the currently-used share of RAM."""
    total = Memory.from_bytes(psutil.virtual_memory().total)
    used = total * get_memory_pressure()
    return get_memory_limit() - used
|
||||
|
||||
|
||||
def get_memory_pressure_threshold() -> float:
    """Express the memory limit as a fraction of total RAM."""
    total_bytes = psutil.virtual_memory().total
    return get_memory_limit().in_bytes / total_bytes
|
||||
|
||||
121
src/exo/shared/types/settings.py
Normal file
121
src/exo/shared/types/settings.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
import tomllib
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
from pydantic import ConfigDict, Field, ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_CONFIG_FILE
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
def _default_memory_threshold() -> float:
    """Default eviction threshold scaled by machine size.

    Larger machines can safely run closer to full; smaller ones need more
    aggressive headroom.
    """
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    for min_gb, threshold in ((128, 0.85), (64, 0.80), (32, 0.75)):
        if total_gb >= min_gb:
            return threshold
    return 0.70
|
||||
|
||||
|
||||
class MemorySettings(CamelCaseModel):
    """Memory / safety settings persisted in the exo config file."""

    # Keep snake_case field names (no camelCase aliasing) and reject unknown
    # keys so config-file typos surface as validation errors.
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="forbid",
        strict=False,
    )

    # Stop generation when memory is critically low.
    oom_prevention: bool = False
    # Fraction of RAM above which KV-cache eviction triggers; default scales
    # with machine size (see _default_memory_threshold).
    memory_threshold: float = Field(default_factory=_default_memory_threshold, ge=0.0, le=1.0)
    # Free-memory reserve in GB.
    memory_floor_gb: float = Field(default=5.0, ge=0.0)
|
||||
|
||||
|
||||
class GenerationSettings(CamelCaseModel):
    """Generation / performance settings persisted in the exo config file."""

    # Keep snake_case field names and reject unknown keys (see MemorySettings).
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="forbid",
        strict=False,
    )

    # Token chunk size used during prompt processing.
    prefill_step_size: int = Field(default=4096, ge=1)
    # Maximum generation length per response.
    max_tokens: int = Field(default=32168, ge=1)
    # KV cache quantization bits; None means full precision.
    kv_cache_bits: Literal[4, 8] | None = None
|
||||
|
||||
|
||||
class ExoSettings(CamelCaseModel):
    """Top-level settings document: the full contents of the exo config file."""

    # Unknown top-level sections are ignored (extra="ignore") so the config
    # file can carry other tools' sections without breaking validation.
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="ignore",
        strict=False,
    )

    memory: MemorySettings = Field(default_factory=MemorySettings)
    generation: GenerationSettings = Field(default_factory=GenerationSettings)
|
||||
|
||||
|
||||
# Process-wide cache of the last parsed settings, keyed by file mtime.
_cached_settings: ExoSettings | None = None
_cached_mtime: float = 0.0


def _env_float(name: str) -> float | None:
    """Parse an optional float env var; warn and ignore malformed values."""
    raw = os.environ.get(name)
    if raw is None:
        return None
    try:
        return float(raw)
    except ValueError:
        logger.warning(f"Ignoring invalid {name}={raw!r}: not a float")
        return None


def load_settings() -> ExoSettings:
    """Load settings from EXO_CONFIG_FILE with mtime-based caching.

    Falls back to defaults when the file is missing or invalid. The
    EXO_MEMORY_THRESHOLD / EXO_MEMORY_FLOOR env vars override the file for
    backward compatibility; malformed env values are ignored with a warning
    instead of crashing startup (previously a bare float() raised ValueError).
    """
    global _cached_settings, _cached_mtime  # noqa: PLW0603

    try:
        mtime = EXO_CONFIG_FILE.stat().st_mtime
        if _cached_settings is not None and mtime == _cached_mtime:
            return _cached_settings
        with open(EXO_CONFIG_FILE, "rb") as f:
            data = tomllib.load(f)
        settings = ExoSettings.model_validate(data)
        _cached_mtime = mtime
    except FileNotFoundError:
        settings = ExoSettings()
    except (tomllib.TOMLDecodeError, ValidationError) as e:
        logger.warning(f"Invalid config file {EXO_CONFIG_FILE}: {e}")
        settings = ExoSettings()

    # Env vars override config file for backward compat.
    env_threshold = _env_float("EXO_MEMORY_THRESHOLD")
    if env_threshold is not None:
        settings = settings.model_copy(
            update={"memory": settings.memory.model_copy(update={"memory_threshold": env_threshold})}
        )
    env_floor = _env_float("EXO_MEMORY_FLOOR")
    if env_floor is not None:
        settings = settings.model_copy(
            update={"memory": settings.memory.model_copy(update={"memory_floor_gb": env_floor})}
        )

    _cached_settings = settings
    return settings
|
||||
|
||||
|
||||
def save_settings(settings: ExoSettings) -> None:
    """Serialize settings to EXO_CONFIG_FILE as TOML and refresh the cache."""
    global _cached_settings, _cached_mtime  # noqa: PLW0603

    EXO_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)

    mem = settings.memory
    gen = settings.generation
    lines = [
        "[memory]",
        f"oom_prevention = {'true' if mem.oom_prevention else 'false'}",
        f"memory_threshold = {mem.memory_threshold}",
        f"memory_floor_gb = {mem.memory_floor_gb}",
        "",
        "[generation]",
        f"prefill_step_size = {gen.prefill_step_size}",
        f"max_tokens = {gen.max_tokens}",
    ]
    # Optional key: omitted entirely when unset (None has no TOML encoding).
    if gen.kv_cache_bits is not None:
        lines.append(f"kv_cache_bits = {gen.kv_cache_bits}")

    EXO_CONFIG_FILE.write_text("\n".join(lines) + "\n")

    # Keep the in-process cache warm so the next load_settings() hit is free.
    _cached_settings = settings
    _cached_mtime = EXO_CONFIG_FILE.stat().st_mtime
|
||||
@@ -12,7 +12,7 @@ from anyio import fail_after, open_process, to_thread
|
||||
from anyio.streams.buffered import BufferedByteReceiveStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ConfigDict, ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -295,6 +295,8 @@ class ThunderboltBridgeInfo(TaggedModel):
|
||||
class NodeConfig(TaggedModel):
|
||||
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> Self | None:
|
||||
cfg_file = anyio.Path(EXO_CONFIG_FILE)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
@@ -12,31 +10,14 @@ from mlx_lm.models.cache import (
|
||||
)
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.memory import Memory, get_memory_pressure
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.settings import load_settings
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in.
|
||||
# Smaller machines need more aggressive eviction.
|
||||
def _default_memory_threshold() -> float:
|
||||
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
|
||||
if total_gb >= 128:
|
||||
return 0.85
|
||||
if total_gb >= 64:
|
||||
return 0.80
|
||||
if total_gb >= 32:
|
||||
return 0.75
|
||||
return 0.70
|
||||
|
||||
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
|
||||
)
|
||||
|
||||
|
||||
class CacheSnapshot:
|
||||
"""Snapshot of states at a known token position."""
|
||||
|
||||
@@ -92,6 +73,15 @@ class KVPrefixCache:
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def force_evict_all(self) -> int:
    """Drop every prefix-cache entry and report how many were evicted."""
    evicted = len(self.caches)
    self.clear()
    if evicted:
        logger.info(
            f"Force-evicted all {evicted} prefix cache entries due to memory pressure"
        )
    return evicted
||||
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
@@ -217,7 +207,7 @@ class KVPrefixCache:
|
||||
# Evict LRU entries until below threshold
|
||||
while (
|
||||
len(self.caches) > 0
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
and self.get_memory_used_percentage() > load_settings().memory.memory_threshold
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
@@ -230,7 +220,7 @@ class KVPrefixCache:
|
||||
)
|
||||
|
||||
def get_memory_used_percentage(self) -> float:
|
||||
local_pressure: float = get_memory_used_percentage()
|
||||
local_pressure: float = get_memory_pressure()
|
||||
|
||||
if self._group is None:
|
||||
return local_pressure
|
||||
@@ -299,15 +289,47 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def get_available_memory() -> Memory:
|
||||
mem: int = psutil.virtual_memory().available
|
||||
return Memory.from_bytes(mem)
|
||||
def _measure_single_cache_bytes(
|
||||
entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
|
||||
) -> int:
|
||||
if isinstance(entry, CacheList):
|
||||
return sum(
|
||||
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
|
||||
for c in entry.caches
|
||||
)
|
||||
|
||||
total = 0
|
||||
if isinstance(entry, ArraysCache):
|
||||
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
for arr in state: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
return total
|
||||
|
||||
total = 0
|
||||
for attr_name in ("keys", "values"):
|
||||
val: object = getattr(entry, attr_name, None)
|
||||
if val is None:
|
||||
continue
|
||||
if isinstance(val, mx.array):
|
||||
total += val.nbytes
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for arr in val: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def get_memory_used_percentage() -> float:
|
||||
mem = psutil.virtual_memory()
|
||||
# percent is 0-100
|
||||
return float(mem.percent / 100)
|
||||
def measure_cache_bytes(cache: KVCacheType) -> int:
|
||||
return sum(_measure_single_cache_bytes(c) for c in cache)
|
||||
|
||||
|
||||
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> Memory:
|
||||
offset = cache_length(cache)
|
||||
if offset == 0:
|
||||
return Memory.from_bytes(0)
|
||||
return Memory.from_bytes(measure_cache_bytes(cache) // offset)
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
@@ -320,13 +342,14 @@ def make_kv_cache(
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
kv_cache_bits = load_settings().generation.kv_cache_bits
|
||||
if kv_cache_bits is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=kv_cache_bits)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
|
||||
@@ -18,8 +18,9 @@ from exo.shared.types.api import (
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.memory import Memory, get_memory_available_locally
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.settings import load_settings
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
@@ -32,17 +33,18 @@ from exo.worker.engines.mlx.cache import (
|
||||
encode_prompt,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
measure_kv_cache_bytes_per_token,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
DEFAULT_TOP_LOGPROBS,
|
||||
KV_BITS,
|
||||
KV_GROUP_SIZE,
|
||||
MAX_TOKENS,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_any,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -110,7 +112,7 @@ def prefill(
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=4096,
|
||||
prefill_step_size=load_settings().generation.prefill_step_size,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
@@ -148,7 +150,8 @@ def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> int:
|
||||
) -> tuple[int, Memory]:
|
||||
"""Run warmup inference and tokens_generated and bytes_per_token"""
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
warmup_prompt = apply_chat_template(
|
||||
@@ -187,9 +190,12 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
|
||||
logger.info(f"Measured KV cache cost: {bytes_per_token} per token")
|
||||
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated
|
||||
return tokens_generated, bytes_per_token
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
@@ -267,6 +273,33 @@ def extract_top_logprobs(
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def _check_memory_budget(
|
||||
bytes_per_token: Memory,
|
||||
total_sequence_tokens: int,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> str | None:
|
||||
if bytes_per_token.in_bytes == 0:
|
||||
return None
|
||||
|
||||
estimated = bytes_per_token * total_sequence_tokens
|
||||
over_budget = estimated > get_memory_available_locally()
|
||||
|
||||
if not mx_any(over_budget, group):
|
||||
return None
|
||||
|
||||
if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
|
||||
mx.clear_cache()
|
||||
over_budget = estimated > get_memory_available_locally()
|
||||
if not mx_any(over_budget, group):
|
||||
return None
|
||||
|
||||
return (
|
||||
"Not enough memory for this conversation. "
|
||||
"Please start a new conversation or compact your messages."
|
||||
)
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -275,7 +308,10 @@ def mlx_generate(
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
bytes_per_token: Memory | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
if bytes_per_token is None:
|
||||
bytes_per_token = Memory()
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
|
||||
@@ -307,6 +343,23 @@ def mlx_generate(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
if bytes_per_token.in_bytes > 0 and load_settings().memory.oom_prevention:
|
||||
oom_error = _check_memory_budget(
|
||||
bytes_per_token=bytes_per_token,
|
||||
total_sequence_tokens=len(all_prompt_tokens),
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
if oom_error is not None:
|
||||
logger.warning(f"OOM prevention (prefill): {oom_error}")
|
||||
yield GenerationResponse(
|
||||
text=oom_error,
|
||||
token=0,
|
||||
finish_reason="error",
|
||||
usage=None,
|
||||
)
|
||||
return
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
@@ -342,7 +395,7 @@ def mlx_generate(
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-2:]
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
max_tokens = task.max_output_tokens or load_settings().generation.max_tokens
|
||||
accumulated_text = ""
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
@@ -99,13 +98,14 @@ def mlx_distributed_init(
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
coordination_file = str(
|
||||
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
coordination_file = None
|
||||
try:
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
@@ -128,6 +128,9 @@ def mlx_distributed_init(
|
||||
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
|
||||
)
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||
|
||||
with open(coordination_file, "w") as f:
|
||||
@@ -147,6 +150,10 @@ def mlx_distributed_init(
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
|
||||
@@ -31,6 +31,12 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import (
|
||||
Memory,
|
||||
get_memory_pressure,
|
||||
get_memory_pressure_threshold,
|
||||
)
|
||||
from exo.shared.types.settings import load_settings
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
@@ -114,6 +120,7 @@ def main(
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
bytes_per_token = Memory.from_bytes(0)
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -225,12 +232,14 @@ def main(
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
toks = warmup_inference(
|
||||
toks, bytes_per_token = warmup_inference(
|
||||
model=cast(Model, inference_model),
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"warmed up by generating {toks} tokens, {bytes_per_token}/token for KV cache"
|
||||
)
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
@@ -310,6 +319,7 @@ def main(
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
group=group,
|
||||
bytes_per_token=bytes_per_token,
|
||||
)
|
||||
|
||||
if tokenizer.has_thinking:
|
||||
@@ -336,6 +346,7 @@ def main(
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = check_for_cancel_every
|
||||
oom_stopped = False
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
@@ -344,7 +355,15 @@ def main(
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
oom_local = (
|
||||
load_settings().memory.oom_prevention
|
||||
and bytes_per_token.in_bytes > 0
|
||||
and get_memory_pressure()
|
||||
> get_memory_pressure_threshold()
|
||||
)
|
||||
if mx_any(want_to_cancel or oom_local, group):
|
||||
if not want_to_cancel:
|
||||
oom_stopped = True
|
||||
break
|
||||
|
||||
match response:
|
||||
@@ -400,6 +419,21 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
if oom_stopped and device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=model_id,
|
||||
error_message=(
|
||||
"Generation stopped: running out of memory. "
|
||||
"Please start a new conversation or compact "
|
||||
"your messages."
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except PrefillCancelled:
|
||||
logger.info(f"Prefill cancelled for task {task.task_id}")
|
||||
# can we make this more explicit?
|
||||
|
||||
@@ -15,6 +15,7 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
@@ -114,7 +115,9 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "warmup_inference", make_nothin((1, Memory.from_bytes(0)))
|
||||
)
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
|
||||
Reference in New Issue
Block a user