Compare commits

...

7 Commits

Author                 SHA1        Message                   Date
Ryuichi Leo Takashige  20ea13f047  use memory pressure       2026-02-26 13:16:28 +00:00
Ryuichi Leo Takashige  4ff1578140  cleanup                   2026-02-25 21:46:31 +00:00
Ryuichi Leo Takashige  a5873bc1fd  remove test script        2026-02-25 20:54:52 +00:00
Ryuichi Leo Takashige  dc1ce2a2cf  cleanup                   2026-02-25 20:54:02 +00:00
Ryuichi Leo Takashige  ff57b00dc6  cleanup                   2026-02-25 20:44:56 +00:00
Ryuichi Leo Takashige  d3222c498a  Loosen conditions         2026-02-25 19:36:05 +00:00
Ryuichi Leo Takashige  2f719d62a7  Handle low memory better  2026-02-25 19:25:18 +00:00
6 changed files with 229 additions and 38 deletions

View File

@@ -249,7 +249,8 @@ class ChunkedKVCache(KVCache):
...
class CacheList(_BaseCache):
    def __init__(self, *caches) -> None: ...
    caches: tuple[_BaseCache, ...]
    def __init__(self, *caches: _BaseCache) -> None: ...
    def __getitem__(self, idx): ...
    def is_trimmable(self):  # -> bool:
        ...

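Reviewer note: this stub change is load-bearing for the new measurement code in cache.py below, which recurses into CacheList.caches; the typed attribute is what lets pyright see that access. A minimal sketch of the pattern the stub now supports (hypothetical helper, not part of this change):

import mlx_lm.models.cache as cache_mod

def flatten(entry: cache_mod._BaseCache) -> list[cache_mod._BaseCache]:
    # CacheList.caches is now declared as tuple[_BaseCache, ...], so this
    # recursion is statically checkable instead of an unknown attribute.
    if isinstance(entry, cache_mod.CacheList):
        out: list[cache_mod._BaseCache] = []
        for c in entry.caches:
            out.extend(flatten(c))
        return out
    return [entry]
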
View File

@@ -1,6 +1,12 @@
import ctypes
import os
import sys
from math import ceil
from typing import Self, overload
import psutil
from exo.shared.logging import logger
from exo.utils.pydantic_ext import FrozenModel
@@ -149,3 +155,78 @@ class Memory(FrozenModel):
        unit = "B"
        return f"{val:.2f}".rstrip("0").rstrip(".") + f" {unit}"
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
        return 0.80
    if total_gb >= 32:
        return 0.75
    return 0.70


MEMORY_THRESHOLD = float(
    os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
MEMORY_FLOOR = Memory.from_gb(float(os.environ.get("EXO_MEMORY_FLOOR", "5")))

_libc: ctypes.CDLL | None = None


def _macos_memorystatus_level() -> int:
    global _libc  # noqa: PLW0603
    if _libc is None:
        _libc = ctypes.CDLL("/usr/lib/libSystem.B.dylib")
    level = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(ctypes.c_int))
    ret: int = _libc.sysctlbyname(  # pyright: ignore[reportAny]
        b"kern.memorystatus_level",
        ctypes.byref(level),
        ctypes.byref(size),
        None,
        ctypes.c_size_t(0),
    )
    if ret != 0:
        raise OSError("sysctlbyname kern.memorystatus_level failed")
    return level.value


def _get_macos_memory_pressure() -> float:
    try:
        return 1.0 - _macos_memorystatus_level() / 100.0
    except (OSError, FileNotFoundError):
        logger.warning("Using fallback memory pressure")
        return _fallback_memory_pressure()


def _fallback_memory_pressure() -> float:
    vm = psutil.virtual_memory()
    return 1.0 - vm.available / vm.total


def get_memory_pressure() -> float:
    if sys.platform == "darwin":
        return _get_macos_memory_pressure()
    return _fallback_memory_pressure()


def get_memory_limit() -> Memory:
    total = psutil.virtual_memory().total
    floor = min(int(total * (1 - MEMORY_THRESHOLD)), MEMORY_FLOOR.in_bytes)
    return Memory.from_bytes(total - floor)


def get_memory_available_locally() -> Memory:
    total = Memory.from_bytes(psutil.virtual_memory().total)
    return get_memory_limit() - total * get_memory_pressure()


def get_memory_pressure_threshold() -> float:
    total = psutil.virtual_memory().total
    return get_memory_limit().in_bytes / total

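Reviewer note: a worked example of how the new helpers compose, assuming a hypothetical 64 GB machine with EXO_MEMORY_THRESHOLD and EXO_MEMORY_FLOOR left unset (numbers are illustrative, not from a real run):

# Hypothetical 64 GB machine, defaults in effect.
total_gb = 64.0
threshold = 0.80  # _default_memory_threshold(): 64 <= total_gb < 128
reserved_gb = min(total_gb * (1 - threshold), 5.0)  # min(12.8, MEMORY_FLOOR) = 5.0
limit_gb = total_gb - reserved_gb  # get_memory_limit() -> 59 GB
pressure = 0.30  # e.g. kern.memorystatus_level == 70 on macOS
available_gb = limit_gb - total_gb * pressure  # get_memory_available_locally() -> 39.8
pressure_threshold = limit_gb / total_gb  # get_memory_pressure_threshold() -> ~0.92
print(limit_gb, available_gb, pressure_threshold)

Note that kern.memorystatus_level reports the percentage of memory macOS still considers available, hence the 1.0 - level / 100.0 inversion in _get_macos_memory_pressure.
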
View File

@@ -1,8 +1,6 @@
import os
from copy import deepcopy
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
    ArraysCache,
    CacheList,
@@ -12,31 +10,13 @@ from mlx_lm.models.cache import (
)
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.memory import MEMORY_THRESHOLD, Memory, get_memory_pressure
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
        return 0.80
    if total_gb >= 32:
        return 0.75
    return 0.70


_MEMORY_THRESHOLD = float(
    os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
class CacheSnapshot:
    """Snapshot of states at a known token position."""
@@ -92,6 +72,15 @@ class KVPrefixCache:
        self._snapshots.clear()
        self._last_used.clear()

    def force_evict_all(self) -> int:
        count = len(self.caches)
        self.clear()
        if count > 0:
            logger.info(
                f"Force-evicted all {count} prefix cache entries due to memory pressure"
            )
        return count

    def add_kv_cache(
        self,
        prompt_tokens: mx.array,
@@ -217,7 +206,7 @@ class KVPrefixCache:
        # Evict LRU entries until below threshold
        while (
            len(self.caches) > 0
            and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
            and self.get_memory_used_percentage() > MEMORY_THRESHOLD
        ):
            lru_index = self._last_used.index(min(self._last_used))
            evicted_tokens = len(self.prompts[lru_index])
@@ -230,7 +219,7 @@ class KVPrefixCache:
            )

    def get_memory_used_percentage(self) -> float:
        local_pressure: float = get_memory_used_percentage()
        local_pressure: float = get_memory_pressure()
        if self._group is None:
            return local_pressure
@@ -299,15 +288,47 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    return int(mx.sum(prefix_mask).item())


def get_available_memory() -> Memory:
    mem: int = psutil.virtual_memory().available
    return Memory.from_bytes(mem)


def _measure_single_cache_bytes(
    entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
    if isinstance(entry, CacheList):
        return sum(
            _measure_single_cache_bytes(c)  # pyright: ignore[reportArgumentType]
            for c in entry.caches
        )

    total = 0
    if isinstance(entry, ArraysCache):
        state = entry.state  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        for arr in state:  # pyright: ignore[reportUnknownVariableType]
            if isinstance(arr, mx.array):
                total += arr.nbytes
        return total

    total = 0
    for attr_name in ("keys", "values"):
        val: object = getattr(entry, attr_name, None)
        if val is None:
            continue
        if isinstance(val, mx.array):
            total += val.nbytes
        elif isinstance(val, (tuple, list)):
            for arr in val:  # pyright: ignore[reportUnknownVariableType]
                if isinstance(arr, mx.array):
                    total += arr.nbytes
    return total


def get_memory_used_percentage() -> float:
    mem = psutil.virtual_memory()
    # percent is 0-100
    return float(mem.percent / 100)


def measure_cache_bytes(cache: KVCacheType) -> int:
    return sum(_measure_single_cache_bytes(c) for c in cache)


def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> Memory:
    offset = cache_length(cache)
    if offset == 0:
        return Memory.from_bytes(0)
    return Memory.from_bytes(measure_cache_bytes(cache) // offset)
def make_kv_cache(

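Reviewer note: measure_kv_cache_bytes_per_token divides the cache's total measured bytes by its current length, i.e. it assumes per-token cost is roughly linear. A standalone check of that arithmetic for a single plain float16 KV layer (shapes are assumptions for illustration, not taken from this diff):

import mlx.core as mx

# One layer's keys/values, shape (batch, n_kv_heads, n_tokens, head_dim).
n_kv_heads, head_dim, n_tokens = 8, 128, 256
keys = mx.zeros((1, n_kv_heads, n_tokens, head_dim), dtype=mx.float16)
values = mx.zeros_like(keys)

# Same arithmetic as measure_kv_cache_bytes_per_token: total bytes // length.
per_token = (keys.nbytes + values.nbytes) // n_tokens
print(per_token)  # 2 (K and V) * 8 heads * 128 dims * 2 bytes = 4096 per layer

Quantized and rotating caches deviate from this linear model, which is presumably why _measure_single_cache_bytes inspects keys/values that may be arrays, tuples, or lists rather than assuming one layout.
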
View File

@@ -18,7 +18,7 @@ from exo.shared.types.api import (
    Usage,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.memory import Memory, get_memory_available_locally
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
@@ -32,6 +32,7 @@ from exo.worker.engines.mlx.cache import (
    encode_prompt,
    has_non_kv_caches,
    make_kv_cache,
    measure_kv_cache_bytes_per_token,
    snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
@@ -43,6 +44,7 @@ from exo.worker.engines.mlx.constants import (
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    fix_unmatched_think_end_tokens,
    mx_any,
    mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -148,7 +150,8 @@ def warmup_inference(
    model: Model,
    tokenizer: TokenizerWrapper,
    group: mx.distributed.Group | None,
) -> int:
) -> tuple[int, Memory]:
    """Run warmup inference and return tokens_generated and bytes_per_token."""
    content = "Prompt to warm up the inference engine. Repeat this."
    warmup_prompt = apply_chat_template(
@@ -187,9 +190,12 @@ def warmup_inference(
    logger.info("Generated ALL warmup tokens")

    bytes_per_token = measure_kv_cache_bytes_per_token(cache)
    logger.info(f"Measured KV cache cost: {bytes_per_token} per token")

    mx_barrier(group)
    return tokens_generated
    return tokens_generated, bytes_per_token
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
@@ -267,6 +273,33 @@ def extract_top_logprobs(
    return selected_logprob, top_logprob_items


def _check_memory_budget(
    bytes_per_token: Memory,
    total_sequence_tokens: int,
    kv_prefix_cache: KVPrefixCache | None,
    group: mx.distributed.Group | None,
) -> str | None:
    if bytes_per_token.in_bytes == 0:
        return None

    estimated = bytes_per_token * total_sequence_tokens
    over_budget = estimated > get_memory_available_locally()
    if not mx_any(over_budget, group):
        return None

    if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
        mx.clear_cache()
        over_budget = estimated > get_memory_available_locally()
        if not mx_any(over_budget, group):
            return None

    return (
        "Not enough memory for this conversation. "
        "Please start a new conversation or compact your messages."
    )


def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
@@ -275,7 +308,10 @@ def mlx_generate(
    kv_prefix_cache: KVPrefixCache | None,
    group: mx.distributed.Group | None,
    on_prefill_progress: Callable[[int, int], None] | None = None,
    bytes_per_token: Memory | None = None,
) -> Generator[GenerationResponse]:
    if bytes_per_token is None:
        bytes_per_token = Memory()

    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()

    # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
@@ -307,6 +343,23 @@ def mlx_generate(
            f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
        )

    if bytes_per_token.in_bytes > 0:
        oom_error = _check_memory_budget(
            bytes_per_token=bytes_per_token,
            total_sequence_tokens=len(all_prompt_tokens),
            kv_prefix_cache=kv_prefix_cache,
            group=group,
        )
        if oom_error is not None:
            logger.warning(f"OOM prevention (prefill): {oom_error}")
            yield GenerationResponse(
                text=oom_error,
                token=0,
                finish_reason="error",
                usage=None,
            )
            return

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
        # Only sample length eos tokens

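Reviewer note: the budget check has to be collective. With a sharded model each rank sees different local memory, and all ranks must take the same branch or they deadlock at the next collective op. mx_any is imported but not shown in this diff; a plausible shape for it, written here purely as an assumption, is an all-reduce over a boolean flag:

import mlx.core as mx

def mx_any_sketch(flag: bool, group: mx.distributed.Group | None) -> bool:
    # Assumed behaviour of the imported mx_any: every rank learns whether
    # any rank raised the flag. Illustrative body, not from this diff.
    if group is None:
        return flag
    total = mx.distributed.all_sum(mx.array([int(flag)]), group=group)
    return bool(total.item() > 0)

This also explains the second check in _check_memory_budget: after force_evict_all() and mx.clear_cache(), the constrained rank may have freed enough memory for every rank to agree to proceed.
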
View File

@@ -31,6 +31,11 @@ from exo.shared.types.events import (
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.memory import (
    Memory,
    get_memory_pressure,
    get_memory_pressure_threshold,
)
from exo.shared.types.tasks import (
    ConnectToGroup,
    LoadModel,
@@ -114,6 +119,7 @@ def main(
    group = None
    kv_prefix_cache: KVPrefixCache | None = None
    check_for_cancel_every: int | None = None
    bytes_per_token = Memory.from_bytes(0)
    current_status: RunnerStatus = RunnerIdle()

    logger.info("runner created")
@@ -225,12 +231,14 @@ def main(
                assert tokenizer
                t = time.monotonic()
                toks = warmup_inference(
                toks, bytes_per_token = warmup_inference(
                    model=cast(Model, inference_model),
                    tokenizer=tokenizer,
                    group=group,
                )
                logger.info(f"warmed up by generating {toks} tokens")
                logger.info(
                    f"warmed up by generating {toks} tokens, {bytes_per_token}/token for KV cache"
                )
                check_for_cancel_every = min(
                    math.ceil(toks / max(time.monotonic() - t, 0.001)), 100
                )
@@ -310,6 +318,7 @@ def main(
                    kv_prefix_cache=kv_prefix_cache,
                    on_prefill_progress=on_prefill_progress,
                    group=group,
                    bytes_per_token=bytes_per_token,
                )

                if tokenizer.has_thinking:
@@ -336,6 +345,7 @@ def main(
                completion_tokens = 0
                tokens_since_last_cancel_check = check_for_cancel_every
                oom_stopped = False

                for response in mlx_generator:
                    tokens_since_last_cancel_check += 1
                    if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -344,7 +354,14 @@ def main(
                        want_to_cancel = (task.task_id in cancelled_tasks) or (
                            TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
                        )
                        if mx_any(want_to_cancel, group):
                        oom_local = (
                            bytes_per_token.in_bytes > 0
                            and get_memory_pressure()
                            > get_memory_pressure_threshold()
                        )
                        if mx_any(want_to_cancel or oom_local, group):
                            if not want_to_cancel:
                                oom_stopped = True
                            break

                    match response:
@@ -400,6 +417,21 @@ def main(
                            )
                        )

                if oom_stopped and device_rank == 0:
                    event_sender.send(
                        ChunkGenerated(
                            command_id=command_id,
                            chunk=ErrorChunk(
                                model=model_id,
                                error_message=(
                                    "Generation stopped: running out of memory. "
                                    "Please start a new conversation or compact "
                                    "your messages."
                                ),
                            ),
                        )
                    )
            except PrefillCancelled:
                logger.info(f"Prefill cancelled for task {task.task_id}")
                # can we make this more explicit?

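Reviewer note: mid-generation OOM detection piggybacks on the existing cancel-check cadence, so it adds no extra synchronisation points. Condensed, the per-interval decision is (a sketch using names from this diff, not the literal code):

from exo.shared.types.memory import (
    get_memory_pressure,
    get_memory_pressure_threshold,
)

def should_stop_for_oom(bytes_per_token_bytes: int) -> bool:
    # Armed only once warmup has measured a non-zero per-token KV cost;
    # compares the used-RAM fraction against the limit/total ratio
    # derived in memory.py.
    return (
        bytes_per_token_bytes > 0
        and get_memory_pressure() > get_memory_pressure_threshold()
    )

The flag is then folded into the same mx_any call as cancellation, and only rank 0 emits the ErrorChunk, per the device_rank == 0 guard above.
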
View File

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import (
    ConnectToGroup,
    LoadModel,
@@ -114,7 +115,9 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # initialize_mlx returns a mock group
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(
        mlx_runner, "warmup_inference", make_nothin((1, Memory.from_bytes(0)))
    )
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
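
Reviewer note: the test only needs to mirror warmup_inference's new tuple return. make_nothin is not shown in this diff; from its usage it presumably builds a stub that ignores its arguments and returns a canned value, along these lines (an assumption, not the repo's code):

from typing import Any

def make_nothin(result: Any):
    def nothin(*_args: Any, **_kwargs: Any) -> Any:
        # Swallow all arguments and return the canned value.
        return result
    return nothin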