Compare commits

...

8 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
0955966b2a Claude-generated settings, no idea if it works 2026-02-26 13:49:05 +00:00
Ryuichi Leo Takashige
20ea13f047 use memory pressure 2026-02-26 13:16:28 +00:00
Ryuichi Leo Takashige
4ff1578140 cleanup 2026-02-25 21:46:31 +00:00
Ryuichi Leo Takashige
a5873bc1fd remove test script 2026-02-25 20:54:52 +00:00
Ryuichi Leo Takashige
dc1ce2a2cf cleanup 2026-02-25 20:54:02 +00:00
Ryuichi Leo Takashige
ff57b00dc6 cleanup 2026-02-25 20:44:56 +00:00
Ryuichi Leo Takashige
d3222c498a Loosen conditions 2026-02-25 19:36:05 +00:00
Ryuichi Leo Takashige
2f719d62a7 Handle low memory better 2026-02-25 19:25:18 +00:00
12 changed files with 675 additions and 45 deletions

View File

@@ -249,7 +249,8 @@ class ChunkedKVCache(KVCache):
... ...
class CacheList(_BaseCache): class CacheList(_BaseCache):
def __init__(self, *caches) -> None: ... caches: tuple[_BaseCache, ...]
def __init__(self, *caches: _BaseCache) -> None: ...
def __getitem__(self, idx): ... def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool: def is_trimmable(self): # -> bool:
... ...

View File

@@ -170,5 +170,30 @@
{/if} {/if}
Downloads Downloads
</a> </a>
<a
href="/#/settings"
class="text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
title="Settings"
>
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
/>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
/>
</svg>
Settings
</a>
</nav> </nav>
</header> </header>

View File

@@ -0,0 +1,87 @@
/**
* SettingsStore - Manages exo runtime settings via the /settings API.
*/
/** Memory-safety settings, mirroring the backend `[memory]` config section. */
export interface MemorySettings {
/** When true, generation is stopped early to avoid running out of memory. */
oom_prevention: boolean;
/** Fraction of total memory (0-1) above which KV cache eviction triggers. */
memory_threshold: number;
/** Minimum free memory to keep in reserve, in gigabytes. */
memory_floor_gb: number;
}
/** Generation/performance settings, mirroring the backend `[generation]` section. */
export interface GenerationSettings {
/** Token chunk size used during prompt prefill. */
prefill_step_size: number;
/** Maximum generation length per response. */
max_tokens: number;
/** KV cache quantization bits; null means full precision. */
kv_cache_bits: 4 | 8 | null;
}
/** Full settings payload exchanged with the `/settings` API. */
export interface ExoSettings {
memory: MemorySettings;
generation: GenerationSettings;
}
/**
 * Client-side fallback defaults, shown until the server's settings load.
 * NOTE(review): these are hardcoded; the server computes some defaults from
 * machine size (e.g. memory_threshold), so they may differ — confirm acceptable.
 */
function defaultSettings(): ExoSettings {
  const memory = {
    oom_prevention: false,
    memory_threshold: 0.8,
    memory_floor_gb: 5.0,
  };
  const generation = {
    prefill_step_size: 4096,
    max_tokens: 32168,
    kv_cache_bits: null,
  };
  return { memory, generation };
}
/**
 * Reactive store for exo runtime settings, backed by the `/settings` HTTP
 * endpoints. `settings` always holds the last known-good value; `error`
 * carries the most recent failure message, if any.
 */
class SettingsStore {
  settings = $state<ExoSettings>(defaultSettings());
  loading = $state(false);
  error = $state<string | null>(null);

  /** Fetch the current settings from the server into `settings`. */
  async load(): Promise<void> {
    this.loading = true;
    this.error = null;
    try {
      const res = await fetch("/settings");
      if (!res.ok) throw new Error(`Failed to fetch settings: ${res.status}`);
      this.settings = (await res.json()) as ExoSettings;
    } catch (e) {
      console.error("Failed to load settings:", e);
      this.error = e instanceof Error ? e.message : "Unknown error";
    } finally {
      this.loading = false;
    }
  }

  /**
   * Persist `updated` to the server; on success `settings` is replaced with
   * the server's echo of the saved value.
   * @returns true when the save succeeded, false otherwise.
   */
  async save(updated: ExoSettings): Promise<boolean> {
    this.loading = true;
    this.error = null;
    try {
      const res = await fetch("/settings", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(updated),
      });
      if (!res.ok) throw new Error(`Failed to save settings: ${res.status}`);
      this.settings = (await res.json()) as ExoSettings;
      return true;
    } catch (e) {
      console.error("Failed to save settings:", e);
      this.error = e instanceof Error ? e.message : "Unknown error";
      return false;
    } finally {
      this.loading = false;
    }
  }

  /** Return a fresh copy of the client-side defaults (does not save). */
  resetToDefaults(): ExoSettings {
    return defaultSettings();
  }
}
export const settingsStore = new SettingsStore();

View File

@@ -0,0 +1,193 @@
<script lang="ts">
// Settings page: loads server settings into a locally editable draft and
// writes the draft back through the settings store on Save.
import { onMount } from "svelte";
import { fade } from "svelte/transition";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import { settingsStore, type ExoSettings } from "$lib/stores/settings.svelte";
import { addToast } from "$lib/stores/toast.svelte";
// Working copy edited by the form controls; null until the initial load
// completes (the template shows a loading state while null).
let draft = $state<ExoSettings | null>(null);
const loading = $derived(settingsStore.loading);
onMount(async () => {
await settingsStore.load();
// Deep-copy so input bindings mutate the draft, not the store's state.
draft = structuredClone(settingsStore.settings);
});
// Save the draft to the server and surface the outcome via a toast.
async function handleSave() {
if (!draft) return;
const ok = await settingsStore.save(draft);
if (ok) {
addToast({ type: "success", message: "Settings saved" });
} else {
addToast({ type: "error", message: settingsStore.error ?? "Failed to save settings" });
}
}
// Resets only the local draft; nothing is persisted until Save is pressed.
// NOTE(review): these are the frontend's hardcoded defaults, which may differ
// from the server's machine-dependent defaults — confirm this is intended.
function handleReset() {
draft = settingsStore.resetToDefaults();
}
// Choices for the KV cache quantization <select>; null = full precision.
const KV_OPTIONS: { label: string; value: 4 | 8 | null }[] = [
{ label: "None (full precision)", value: null },
{ label: "4-bit", value: 4 },
{ label: "8-bit", value: 8 },
];
</script>
<HeaderNav showHome={true} />
{#if draft}
<div class="min-h-screen bg-background text-foreground" in:fade={{ duration: 200 }}>
<div class="max-w-2xl mx-auto px-6 py-8">
<h1 class="text-2xl font-bold text-exo-yellow tracking-wider uppercase mb-8">Settings</h1>
<!-- Memory / Safety -->
<section class="mb-10">
<h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Memory / Safety</h2>
<div class="space-y-5">
<!-- OOM Prevention Toggle -->
<div class="flex items-center justify-between">
<div>
<div class="text-sm text-white/90">OOM Prevention</div>
<div class="text-xs text-white/40 mt-0.5">Stop generation when memory is low</div>
</div>
<button
onclick={() => { if (draft) draft.memory.oom_prevention = !draft.memory.oom_prevention; }}
class="relative w-11 h-6 rounded-full transition-colors duration-200 cursor-pointer {draft.memory.oom_prevention ? 'bg-exo-yellow' : 'bg-exo-medium-gray'}"
role="switch"
aria-checked={draft.memory.oom_prevention}
>
<span
class="absolute top-0.5 left-0.5 w-5 h-5 rounded-full bg-white shadow transition-transform duration-200 {draft.memory.oom_prevention ? 'translate-x-5' : 'translate-x-0'}"
></span>
</button>
</div>
<!-- Memory Threshold Slider -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Memory Threshold</div>
<div class="text-xs text-white/40 mt-0.5">KV cache eviction triggers above this level</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{(draft.memory.memory_threshold * 100).toFixed(0)}%</span>
</div>
<input
type="range"
min="0.5"
max="0.99"
step="0.01"
bind:value={draft.memory.memory_threshold}
class="w-full h-1.5 rounded-full appearance-none cursor-pointer bg-exo-medium-gray accent-exo-yellow"
/>
</div>
<!-- Memory Floor -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Memory Floor</div>
<div class="text-xs text-white/40 mt-0.5">Minimum free memory to reserve (GB)</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.memory.memory_floor_gb.toFixed(1)} GB</span>
</div>
<input
type="number"
min="0"
max="64"
step="0.5"
bind:value={draft.memory.memory_floor_gb}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
</div>
</section>
<!-- Generation / Performance -->
<section class="mb-10">
<h2 class="text-sm font-semibold text-white/50 tracking-widest uppercase mb-4">Generation / Performance</h2>
<div class="space-y-5">
<!-- Prefill Step Size -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Prefill Step Size</div>
<div class="text-xs text-white/40 mt-0.5">Token chunk size during prompt processing</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.generation.prefill_step_size.toLocaleString()}</span>
</div>
<input
type="number"
min="128"
max="32768"
step="128"
bind:value={draft.generation.prefill_step_size}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<!-- Max Tokens -->
<div>
<div class="flex items-center justify-between mb-1.5">
<div>
<div class="text-sm text-white/90">Max Tokens</div>
<div class="text-xs text-white/40 mt-0.5">Maximum generation length per response</div>
</div>
<span class="text-sm font-mono text-exo-yellow">{draft.generation.max_tokens.toLocaleString()}</span>
</div>
<input
type="number"
min="1"
max="131072"
step="1024"
bind:value={draft.generation.max_tokens}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<!-- KV Cache Bits -->
<div>
<div class="mb-1.5">
<div class="text-sm text-white/90">KV Cache Quantization</div>
<div class="text-xs text-white/40 mt-0.5">Lower bits save memory at slight quality cost</div>
</div>
<select
bind:value={draft.generation.kv_cache_bits}
class="w-full bg-exo-medium-gray border border-exo-light-gray/20 rounded px-3 py-1.5 text-sm text-white/90 font-mono focus:outline-none focus:border-exo-yellow/50 cursor-pointer"
>
{#each KV_OPTIONS as opt}
<option value={opt.value}>{opt.label}</option>
{/each}
</select>
</div>
</div>
</section>
<!-- Action Buttons -->
<div class="flex items-center gap-3">
<button
onclick={handleSave}
disabled={loading}
class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker
disabled:opacity-50 disabled:cursor-not-allowed"
>
{loading ? "Saving..." : "Save"}
</button>
<button
onclick={handleReset}
disabled={loading}
class="px-5 py-2 rounded text-sm font-semibold tracking-wider uppercase transition-colors cursor-pointer
border border-exo-light-gray/30 text-white/70 hover:border-exo-yellow/50 hover:text-exo-yellow
disabled:opacity-50 disabled:cursor-not-allowed"
>
Reset to Defaults
</button>
</div>
</div>
</div>
{:else}
<div class="min-h-screen bg-background flex items-center justify-center">
<div class="text-white/40 text-sm">Loading settings...</div>
</div>
{/if}

View File

@@ -166,6 +166,13 @@ from exo.shared.types.openai_responses import (
ResponsesRequest, ResponsesRequest,
ResponsesResponse, ResponsesResponse,
) )
from exo.shared.types.settings import (
ExoSettings,
load_settings,
)
from exo.shared.types.settings import (
save_settings as save_settings_to_file,
)
from exo.shared.types.state import State from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -349,6 +356,8 @@ class API:
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw) self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
self.app.get("/onboarding")(self.get_onboarding) self.app.get("/onboarding")(self.get_onboarding)
self.app.post("/onboarding")(self.complete_onboarding) self.app.post("/onboarding")(self.complete_onboarding)
self.app.get("/settings")(self.get_settings)
self.app.post("/settings")(self.save_settings)
async def place_instance(self, payload: PlaceInstanceParams): async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance( command = PlaceInstance(
@@ -1825,3 +1834,13 @@ class API:
ONBOARDING_COMPLETE_FILE.parent.mkdir(parents=True, exist_ok=True) ONBOARDING_COMPLETE_FILE.parent.mkdir(parents=True, exist_ok=True)
ONBOARDING_COMPLETE_FILE.write_text("true") ONBOARDING_COMPLETE_FILE.write_text("true")
return JSONResponse({"completed": True}) return JSONResponse({"completed": True})
async def get_settings(self) -> JSONResponse:
settings = load_settings()
return JSONResponse(settings.model_dump())
async def save_settings(self, request: Request) -> JSONResponse:
body = cast(object, await request.json())
settings = ExoSettings.model_validate(body)
save_settings_to_file(settings)
return JSONResponse(settings.model_dump())

View File

@@ -1,6 +1,11 @@
import ctypes
import sys
from math import ceil from math import ceil
from typing import Self, overload from typing import Self, overload
import psutil
from exo.shared.logging import logger
from exo.utils.pydantic_ext import FrozenModel from exo.utils.pydantic_ext import FrozenModel
@@ -149,3 +154,67 @@ class Memory(FrozenModel):
unit = "B" unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}" return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
def _load_memory_settings() -> tuple[float, "Memory"]:
    """Return (memory_threshold, memory_floor) from the persisted settings.

    The import is deferred because settings.py imports Memory from this
    module, which would otherwise create a circular dependency.
    """
    from exo.shared.types.settings import load_settings

    cfg = load_settings()
    return cfg.memory.memory_threshold, Memory.from_gb(cfg.memory.memory_floor_gb)
# Lazily-loaded handle to libSystem, cached across calls.
_libc: ctypes.CDLL | None = None


def _macos_memorystatus_level() -> int:
    """Read the kern.memorystatus_level sysctl on macOS.

    Callers treat the returned value as a 0-100 availability percentage.
    Raises OSError when the sysctl call fails.
    """
    global _libc  # noqa: PLW0603
    if _libc is None:
        _libc = ctypes.CDLL("/usr/lib/libSystem.B.dylib")
    level = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(ctypes.c_int))
    status: int = _libc.sysctlbyname(  # pyright: ignore[reportAny]
        b"kern.memorystatus_level",
        ctypes.byref(level),
        ctypes.byref(size),
        None,
        ctypes.c_size_t(0),
    )
    if status != 0:
        raise OSError("sysctlbyname kern.memorystatus_level failed")
    return level.value
def _get_macos_memory_pressure() -> float:
    """Memory pressure in [0, 1] derived from the macOS memorystatus level.

    Falls back to psutil-based accounting when the sysctl is unavailable.
    """
    try:
        # memorystatus_level is percent available, so pressure is its complement.
        return 1.0 - _macos_memorystatus_level() / 100.0
    except OSError:
        # FileNotFoundError (e.g. a missing libSystem path in CDLL) is a
        # subclass of OSError, so one clause covers both failure modes.
        logger.warning("Using fallback memory pressure")
        return _fallback_memory_pressure()
def _fallback_memory_pressure() -> float:
    """Pressure as the fraction of total memory not currently available (psutil)."""
    stats = psutil.virtual_memory()
    return 1.0 - stats.available / stats.total
def get_memory_pressure() -> float:
    """Current memory pressure in [0, 1]; macOS uses the kernel memorystatus sysctl."""
    if sys.platform != "darwin":
        return _fallback_memory_pressure()
    return _get_macos_memory_pressure()
def get_memory_limit() -> Memory:
    """Usable memory budget: total minus a safety margin.

    The margin is the smaller of (1 - threshold) * total and the configured
    floor. NOTE(review): min() means the floor *caps* the reserve on large
    machines rather than guaranteeing it — confirm that is the intent.
    """
    threshold, floor = _load_memory_settings()
    total_bytes = psutil.virtual_memory().total
    headroom = min(int(total_bytes * (1 - threshold)), floor.in_bytes)
    return Memory.from_bytes(total_bytes - headroom)
def get_memory_available_locally() -> Memory:
    """Memory still spendable locally: the budget minus pressure-scaled usage."""
    total_mem = Memory.from_bytes(psutil.virtual_memory().total)
    in_use = total_mem * get_memory_pressure()
    return get_memory_limit() - in_use
def get_memory_pressure_threshold() -> float:
    """The memory-limit budget expressed as a fraction of total system memory."""
    total_bytes = psutil.virtual_memory().total
    return get_memory_limit().in_bytes / total_bytes

View File

@@ -0,0 +1,121 @@
import os
import tomllib
from typing import Literal
import psutil
from pydantic import ConfigDict, Field, ValidationError
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.logging import logger
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
def _default_memory_threshold() -> float:
    """Default eviction threshold scaled by machine size.

    Smaller machines get a lower threshold (more aggressive eviction).
    """
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    for min_gb, threshold in ((128, 0.85), (64, 0.80), (32, 0.75)):
        if total_gb >= min_gb:
            return threshold
    return 0.70
class MemorySettings(CamelCaseModel):
    # Accept plain snake_case field names (no camelCase aliases) and reject
    # unknown keys so typos in the TOML file surface as validation errors.
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="forbid",
        strict=False,
    )

    # When true, generation is stopped early to avoid running out of memory.
    oom_prevention: bool = False
    # Fraction of total memory (0-1) above which KV cache eviction triggers;
    # the default scales with machine size via _default_memory_threshold().
    memory_threshold: float = Field(default_factory=_default_memory_threshold, ge=0.0, le=1.0)
    # Minimum free memory to keep in reserve, in gigabytes.
    memory_floor_gb: float = Field(default=5.0, ge=0.0)
class GenerationSettings(CamelCaseModel):
    # Same policy as MemorySettings: snake_case names only, unknown keys rejected.
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="forbid",
        strict=False,
    )

    # Token chunk size used during prompt prefill.
    prefill_step_size: int = Field(default=4096, ge=1)
    # Default cap on generated tokens when the request does not specify one.
    # NOTE(review): 32168 looks like a typo for 32768 (2**15), though the
    # frontend default matches this value — confirm which is intended.
    max_tokens: int = Field(default=32168, ge=1)
    # KV cache quantization bits; None means full precision.
    kv_cache_bits: Literal[4, 8] | None = None
class ExoSettings(CamelCaseModel):
    # extra="ignore" (unlike the nested sections) because EXO_CONFIG_FILE may
    # contain top-level keys owned by other consumers of the same file.
    model_config = ConfigDict(
        alias_generator=None,
        validate_by_name=True,
        extra="ignore",
        strict=False,
    )

    memory: MemorySettings = Field(default_factory=MemorySettings)
    generation: GenerationSettings = Field(default_factory=GenerationSettings)
# Process-wide cache of the last loaded settings, keyed by the config file's
# mtime so edits on disk are picked up without re-parsing on every call.
_cached_settings: ExoSettings | None = None
_cached_mtime: float = 0.0


def _env_float(name: str) -> float | None:
    """Parse an optional float env var; warn and ignore malformed values."""
    raw = os.environ.get(name)
    if raw is None:
        return None
    try:
        return float(raw)
    except ValueError:
        logger.warning(f"Ignoring invalid {name}={raw!r}")
        return None


def load_settings() -> ExoSettings:
    """Load settings from EXO_CONFIG_FILE with mtime-based caching.

    A missing or invalid file yields defaults. The EXO_MEMORY_THRESHOLD and
    EXO_MEMORY_FLOOR env vars override the file for backward compatibility;
    malformed values are logged and ignored rather than crashing the caller.
    """
    global _cached_settings, _cached_mtime  # noqa: PLW0603
    try:
        mtime = EXO_CONFIG_FILE.stat().st_mtime
        if _cached_settings is not None and mtime == _cached_mtime:
            return _cached_settings
        with open(EXO_CONFIG_FILE, "rb") as f:
            data = tomllib.load(f)
        settings = ExoSettings.model_validate(data)
        _cached_mtime = mtime
    except FileNotFoundError:
        settings = ExoSettings()
    except (tomllib.TOMLDecodeError, ValidationError) as e:
        logger.warning(f"Invalid config file {EXO_CONFIG_FILE}: {e}")
        settings = ExoSettings()
        # Cache the failure under this mtime so a broken file is not re-parsed
        # (and re-logged) on every call until it actually changes.
        _cached_mtime = mtime
    # Env vars override config file for backward compat.
    env_threshold = _env_float("EXO_MEMORY_THRESHOLD")
    if env_threshold is not None:
        settings = settings.model_copy(
            update={"memory": settings.memory.model_copy(update={"memory_threshold": env_threshold})}
        )
    env_floor = _env_float("EXO_MEMORY_FLOOR")
    if env_floor is not None:
        settings = settings.model_copy(
            update={"memory": settings.memory.model_copy(update={"memory_floor_gb": env_floor})}
        )
    _cached_settings = settings
    return settings
def save_settings(settings: ExoSettings) -> None:
    """Persist settings to EXO_CONFIG_FILE and refresh the module-level cache.

    TOML is emitted by hand because the stdlib tomllib module is read-only
    (there is no dump/dumps counterpart).

    NOTE(review): this rewrites the entire EXO_CONFIG_FILE; if other sections
    live in the same file (ExoSettings uses extra="ignore", suggesting they
    might), they would be dropped here — confirm.
    """
    global _cached_settings, _cached_mtime  # noqa: PLW0603
    EXO_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
    lines = [
        "[memory]",
        f"oom_prevention = {'true' if settings.memory.oom_prevention else 'false'}",
        f"memory_threshold = {settings.memory.memory_threshold}",
        f"memory_floor_gb = {settings.memory.memory_floor_gb}",
        "",
        "[generation]",
        f"prefill_step_size = {settings.generation.prefill_step_size}",
        f"max_tokens = {settings.generation.max_tokens}",
    ]
    # Omit kv_cache_bits when None so load_settings() falls back to the default.
    if settings.generation.kv_cache_bits is not None:
        lines.append(f"kv_cache_bits = {settings.generation.kv_cache_bits}")
    EXO_CONFIG_FILE.write_text("\n".join(lines) + "\n")
    # Record the freshly written mtime so the next load_settings() hits the cache.
    _cached_settings = settings
    _cached_mtime = EXO_CONFIG_FILE.stat().st_mtime

View File

@@ -12,7 +12,7 @@ from anyio import fail_after, open_process, to_thread
from anyio.streams.buffered import BufferedByteReceiveStream from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream from anyio.streams.text import TextReceiveStream
from loguru import logger from loguru import logger
from pydantic import ValidationError from pydantic import ConfigDict, ValidationError
from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
@@ -295,6 +295,8 @@ class ThunderboltBridgeInfo(TaggedModel):
class NodeConfig(TaggedModel): class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there""" """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
model_config = ConfigDict(extra="ignore")
@classmethod @classmethod
async def gather(cls) -> Self | None: async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE) cfg_file = anyio.Path(EXO_CONFIG_FILE)

View File

@@ -1,8 +1,6 @@
import os
from copy import deepcopy from copy import deepcopy
import mlx.core as mx import mlx.core as mx
import psutil
from mlx_lm.models.cache import ( from mlx_lm.models.cache import (
ArraysCache, ArraysCache,
CacheList, CacheList,
@@ -12,31 +10,14 @@ from mlx_lm.models.cache import (
) )
from mlx_lm.tokenizer_utils import TokenizerWrapper from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory, get_memory_pressure
from exo.shared.types.mlx import KVCacheType from exo.shared.types.mlx import KVCacheType
from exo.shared.types.settings import load_settings
from exo.worker.engines.mlx import Model from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE
from exo.worker.runner.bootstrap import logger from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
if total_gb >= 128:
return 0.85
if total_gb >= 64:
return 0.80
if total_gb >= 32:
return 0.75
return 0.70
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
class CacheSnapshot: class CacheSnapshot:
"""Snapshot of states at a known token position.""" """Snapshot of states at a known token position."""
@@ -92,6 +73,15 @@ class KVPrefixCache:
self._snapshots.clear() self._snapshots.clear()
self._last_used.clear() self._last_used.clear()
def force_evict_all(self) -> int:
    """Drop every prefix-cache entry and report how many were evicted."""
    evicted = len(self.caches)
    self.clear()
    if evicted > 0:
        logger.info(
            f"Force-evicted all {evicted} prefix cache entries due to memory pressure"
        )
    return evicted
def add_kv_cache( def add_kv_cache(
self, self,
prompt_tokens: mx.array, prompt_tokens: mx.array,
@@ -217,7 +207,7 @@ class KVPrefixCache:
# Evict LRU entries until below threshold # Evict LRU entries until below threshold
while ( while (
len(self.caches) > 0 len(self.caches) > 0
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD and self.get_memory_used_percentage() > load_settings().memory.memory_threshold
): ):
lru_index = self._last_used.index(min(self._last_used)) lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index]) evicted_tokens = len(self.prompts[lru_index])
@@ -230,7 +220,7 @@ class KVPrefixCache:
) )
def get_memory_used_percentage(self) -> float: def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage() local_pressure: float = get_memory_pressure()
if self._group is None: if self._group is None:
return local_pressure return local_pressure
@@ -299,15 +289,47 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item()) return int(mx.sum(prefix_mask).item())
def get_available_memory() -> Memory: def _measure_single_cache_bytes(
mem: int = psutil.virtual_memory().available entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
return Memory.from_bytes(mem) ) -> int:
if isinstance(entry, CacheList):
return sum(
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
for c in entry.caches
)
total = 0
if isinstance(entry, ArraysCache):
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
for arr in state: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
total = 0
for attr_name in ("keys", "values"):
val: object = getattr(entry, attr_name, None)
if val is None:
continue
if isinstance(val, mx.array):
total += val.nbytes
elif isinstance(val, (tuple, list)):
for arr in val: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
def get_memory_used_percentage() -> float: def measure_cache_bytes(cache: KVCacheType) -> int:
mem = psutil.virtual_memory() return sum(_measure_single_cache_bytes(c) for c in cache)
# percent is 0-100
return float(mem.percent / 100)
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> Memory:
offset = cache_length(cache)
if offset == 0:
return Memory.from_bytes(0)
return Memory.from_bytes(measure_cache_bytes(cache) // offset)
def make_kv_cache( def make_kv_cache(
@@ -320,13 +342,14 @@ def make_kv_cache(
return model.make_cache() # type: ignore return model.make_cache() # type: ignore
if max_kv_size is None: if max_kv_size is None:
if KV_CACHE_BITS is None: kv_cache_bits = load_settings().generation.kv_cache_bits
if kv_cache_bits is None:
logger.info("Using default KV cache") logger.info("Using default KV cache")
return [KVCache() for _ in model.layers] return [KVCache() for _ in model.layers]
else: else:
logger.info("Using quantized KV cache") logger.info("Using quantized KV cache")
return [ return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS) QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=kv_cache_bits)
for _ in model.layers for _ in model.layers
] ]
else: else:

View File

@@ -18,8 +18,9 @@ from exo.shared.types.api import (
Usage, Usage,
) )
from exo.shared.types.common import ModelId from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory, get_memory_available_locally
from exo.shared.types.mlx import KVCacheType from exo.shared.types.mlx import KVCacheType
from exo.shared.types.settings import load_settings
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import ( from exo.shared.types.worker.runner_response import (
GenerationResponse, GenerationResponse,
@@ -32,17 +33,18 @@ from exo.worker.engines.mlx.cache import (
encode_prompt, encode_prompt,
has_non_kv_caches, has_non_kv_caches,
make_kv_cache, make_kv_cache,
measure_kv_cache_bytes_per_token,
snapshot_ssm_states, snapshot_ssm_states,
) )
from exo.worker.engines.mlx.constants import ( from exo.worker.engines.mlx.constants import (
DEFAULT_TOP_LOGPROBS, DEFAULT_TOP_LOGPROBS,
KV_BITS, KV_BITS,
KV_GROUP_SIZE, KV_GROUP_SIZE,
MAX_TOKENS,
) )
from exo.worker.engines.mlx.utils_mlx import ( from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template, apply_chat_template,
fix_unmatched_think_end_tokens, fix_unmatched_think_end_tokens,
mx_any,
mx_barrier, mx_barrier,
) )
from exo.worker.runner.bootstrap import logger from exo.worker.runner.bootstrap import logger
@@ -110,7 +112,7 @@ def prefill(
max_tokens=1, max_tokens=1,
sampler=sampler, sampler=sampler,
prompt_cache=cache, prompt_cache=cache,
prefill_step_size=4096, prefill_step_size=load_settings().generation.prefill_step_size,
kv_group_size=KV_GROUP_SIZE, kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS, kv_bits=KV_BITS,
prompt_progress_callback=progress_callback, prompt_progress_callback=progress_callback,
@@ -148,7 +150,8 @@ def warmup_inference(
model: Model, model: Model,
tokenizer: TokenizerWrapper, tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None, group: mx.distributed.Group | None,
) -> int: ) -> tuple[int, Memory]:
"""Run warmup inference and tokens_generated and bytes_per_token"""
content = "Prompt to warm up the inference engine. Repeat this." content = "Prompt to warm up the inference engine. Repeat this."
warmup_prompt = apply_chat_template( warmup_prompt = apply_chat_template(
@@ -187,9 +190,12 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens") logger.info("Generated ALL warmup tokens")
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
logger.info(f"Measured KV cache cost: {bytes_per_token} per token")
mx_barrier(group) mx_barrier(group)
return tokens_generated return tokens_generated, bytes_per_token
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]: def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
@@ -267,6 +273,33 @@ def extract_top_logprobs(
return selected_logprob, top_logprob_items return selected_logprob, top_logprob_items
def _check_memory_budget(
    bytes_per_token: Memory,
    total_sequence_tokens: int,
    kv_prefix_cache: KVPrefixCache | None,
    group: mx.distributed.Group | None,
) -> str | None:
    """Return a user-facing error string when the estimated KV cache for this
    sequence would exceed locally available memory, or None to proceed.

    When over budget, first frees the prefix cache and re-checks before
    giving up.
    """
    if bytes_per_token.in_bytes == 0:
        # No per-token measurement available; skip the check entirely.
        return None
    estimated = bytes_per_token * total_sequence_tokens

    def over_budget() -> bool:
        # mx_any: all ranks must agree, so any rank being over aborts everywhere.
        return mx_any(estimated > get_memory_available_locally(), group)

    if not over_budget():
        return None
    # Try to reclaim memory by dropping all cached prefixes, then re-check.
    if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
        mx.clear_cache()
        if not over_budget():
            return None
    return (
        "Not enough memory for this conversation. "
        "Please start a new conversation or compact your messages."
    )
def mlx_generate( def mlx_generate(
model: Model, model: Model,
tokenizer: TokenizerWrapper, tokenizer: TokenizerWrapper,
@@ -275,7 +308,10 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None, kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None, group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None, on_prefill_progress: Callable[[int, int], None] | None = None,
bytes_per_token: Memory | None = None,
) -> Generator[GenerationResponse]: ) -> Generator[GenerationResponse]:
if bytes_per_token is None:
bytes_per_token = Memory()
# Ensure that generation stats only contains peak memory for this generation # Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory() mx.reset_peak_memory()
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42. # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
@@ -307,6 +343,23 @@ def mlx_generate(
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)" f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
) )
if bytes_per_token.in_bytes > 0 and load_settings().memory.oom_prevention:
oom_error = _check_memory_budget(
bytes_per_token=bytes_per_token,
total_sequence_tokens=len(all_prompt_tokens),
kv_prefix_cache=kv_prefix_cache,
group=group,
)
if oom_error is not None:
logger.warning(f"OOM prevention (prefill): {oom_error}")
yield GenerationResponse(
text=oom_error,
token=0,
finish_reason="error",
usage=None,
)
return
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = [] logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench: if is_bench:
# Only sample length eos tokens # Only sample length eos tokens
@@ -342,7 +395,7 @@ def mlx_generate(
# stream_generate starts from the last token # stream_generate starts from the last token
last_token = prompt_tokens[-2:] last_token = prompt_tokens[-2:]
max_tokens = task.max_output_tokens or MAX_TOKENS max_tokens = task.max_output_tokens or load_settings().generation.max_tokens
accumulated_text = "" accumulated_text = ""
generated_text_parts: list[str] = [] generated_text_parts: list[str] = []
generation_start_time = time.perf_counter() generation_start_time = time.perf_counter()

View File

@@ -31,6 +31,12 @@ from exo.shared.types.events import (
TaskAcknowledged, TaskAcknowledged,
TaskStatusUpdated, TaskStatusUpdated,
) )
from exo.shared.types.memory import (
Memory,
get_memory_pressure,
get_memory_pressure_threshold,
)
from exo.shared.types.settings import load_settings
from exo.shared.types.tasks import ( from exo.shared.types.tasks import (
ConnectToGroup, ConnectToGroup,
LoadModel, LoadModel,
@@ -114,6 +120,7 @@ def main(
group = None group = None
kv_prefix_cache: KVPrefixCache | None = None kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None check_for_cancel_every: int | None = None
bytes_per_token = Memory.from_bytes(0)
current_status: RunnerStatus = RunnerIdle() current_status: RunnerStatus = RunnerIdle()
logger.info("runner created") logger.info("runner created")
@@ -225,12 +232,14 @@ def main(
assert tokenizer assert tokenizer
t = time.monotonic() t = time.monotonic()
toks = warmup_inference( toks, bytes_per_token = warmup_inference(
model=cast(Model, inference_model), model=cast(Model, inference_model),
tokenizer=tokenizer, tokenizer=tokenizer,
group=group, group=group,
) )
logger.info(f"warmed up by generating {toks} tokens") logger.info(
f"warmed up by generating {toks} tokens, {bytes_per_token}/token for KV cache"
)
check_for_cancel_every = min( check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100 math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
) )
@@ -310,6 +319,7 @@ def main(
kv_prefix_cache=kv_prefix_cache, kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress, on_prefill_progress=on_prefill_progress,
group=group, group=group,
bytes_per_token=bytes_per_token,
) )
if tokenizer.has_thinking: if tokenizer.has_thinking:
@@ -336,6 +346,7 @@ def main(
completion_tokens = 0 completion_tokens = 0
tokens_since_last_cancel_check = check_for_cancel_every tokens_since_last_cancel_check = check_for_cancel_every
oom_stopped = False
for response in mlx_generator: for response in mlx_generator:
tokens_since_last_cancel_check += 1 tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every: if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -344,7 +355,15 @@ def main(
want_to_cancel = (task.task_id in cancelled_tasks) or ( want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
) )
if mx_any(want_to_cancel, group): oom_local = (
load_settings().memory.oom_prevention
and bytes_per_token.in_bytes > 0
and get_memory_pressure()
> get_memory_pressure_threshold()
)
if mx_any(want_to_cancel or oom_local, group):
if not want_to_cancel:
oom_stopped = True
break break
match response: match response:
@@ -400,6 +419,21 @@ def main(
) )
) )
if oom_stopped and device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=model_id,
error_message=(
"Generation stopped: running out of memory. "
"Please start a new conversation or compact "
"your messages."
),
),
)
)
except PrefillCancelled: except PrefillCancelled:
logger.info(f"Prefill cancelled for task {task.task_id}") logger.info(f"Prefill cancelled for task {task.task_id}")
# can we make this more explicit? # can we make this more explicit?

View File

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
TaskAcknowledged, TaskAcknowledged,
TaskStatusUpdated, TaskStatusUpdated,
) )
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ( from exo.shared.types.tasks import (
ConnectToGroup, ConnectToGroup,
LoadModel, LoadModel,
@@ -114,7 +115,9 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group # initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup())) monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer))) monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1)) monkeypatch.setattr(
mlx_runner, "warmup_inference", make_nothin((1, Memory.from_bytes(0)))
)
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin) monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False)) monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1). # Mock apply_chat_template since we're using a fake tokenizer (integer 1).