mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
7 Commits
remove-pyt
...
leo/handle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ff1578140 | ||
|
|
a5873bc1fd | ||
|
|
dc1ce2a2cf | ||
|
|
ff57b00dc6 | ||
|
|
d3222c498a | ||
|
|
2f719d62a7 | ||
|
|
ba611f9cd0 |
@@ -249,7 +249,8 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
|
||||
class CacheList(_BaseCache):
|
||||
def __init__(self, *caches) -> None: ...
|
||||
caches: tuple[_BaseCache, ...]
|
||||
def __init__(self, *caches: _BaseCache) -> None: ...
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
|
||||
@@ -27,6 +27,7 @@ dependencies = [
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux==0.15.5",
|
||||
"python-multipart>=0.0.21",
|
||||
"msgspec>=0.19.0",
|
||||
"zstandard>=0.23.0",
|
||||
]
|
||||
|
||||
@@ -542,13 +542,10 @@ class InfoGatherer:
|
||||
if not p.stdout:
|
||||
logger.critical("MacMon closed stdout")
|
||||
return
|
||||
t = TextReceiveStream(BufferedByteReceiveStream(p.stdout))
|
||||
while True:
|
||||
with anyio.fail_after(self.macmon_interval * 3):
|
||||
macmon_output = await t.receive()
|
||||
await self.info_sender.send(
|
||||
MacmonMetrics.from_raw_json(macmon_output)
|
||||
)
|
||||
async for text in TextReceiveStream(
|
||||
BufferedByteReceiveStream(p.stdout)
|
||||
):
|
||||
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
@@ -559,12 +556,8 @@ class InfoGatherer:
|
||||
else str(stderr_output)
|
||||
)
|
||||
logger.warning(
|
||||
f"memory monitor failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"memory monitor silent for {self.macmon_interval * 3}s - reloading"
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning("Error in memory monitor")
|
||||
logger.warning(f"Error in macmon monitor: {e}")
|
||||
await anyio.sleep(self.macmon_interval)
|
||||
|
||||
@@ -32,7 +32,7 @@ def _default_memory_threshold() -> float:
|
||||
return 0.70
|
||||
|
||||
|
||||
_MEMORY_THRESHOLD = float(
|
||||
MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
|
||||
)
|
||||
|
||||
@@ -92,6 +92,15 @@ class KVPrefixCache:
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def force_evict_all(self) -> int:
|
||||
count = len(self.caches)
|
||||
self.clear()
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Force-evicted all {count} prefix cache entries due to memory pressure"
|
||||
)
|
||||
return count
|
||||
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
@@ -217,7 +226,7 @@ class KVPrefixCache:
|
||||
# Evict LRU entries until below threshold
|
||||
while (
|
||||
len(self.caches) > 0
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
and self.get_memory_used_percentage() > MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
@@ -310,6 +319,59 @@ def get_memory_used_percentage() -> float:
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def get_safety_floor() -> int:
|
||||
total = psutil.virtual_memory().total
|
||||
return min(int(total * 0.10), 5 * 1024**3)
|
||||
|
||||
|
||||
def get_memory_pressure_threshold() -> float:
|
||||
total = psutil.virtual_memory().total
|
||||
return 1.0 - get_safety_floor() / total
|
||||
|
||||
|
||||
def _measure_single_cache_bytes(
|
||||
entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
|
||||
) -> int:
|
||||
if isinstance(entry, CacheList):
|
||||
return sum(
|
||||
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
|
||||
for c in entry.caches
|
||||
)
|
||||
|
||||
total = 0
|
||||
if isinstance(entry, ArraysCache):
|
||||
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
for arr in state: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
return total
|
||||
|
||||
total = 0
|
||||
for attr_name in ("keys", "values"):
|
||||
val: object = getattr(entry, attr_name, None)
|
||||
if val is None:
|
||||
continue
|
||||
if isinstance(val, mx.array):
|
||||
total += val.nbytes
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for arr in val: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def measure_cache_bytes(cache: KVCacheType) -> int:
|
||||
return sum(_measure_single_cache_bytes(c) for c in cache)
|
||||
|
||||
|
||||
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> int:
|
||||
offset = cache_length(cache)
|
||||
if offset == 0:
|
||||
return 0
|
||||
return measure_cache_bytes(cache) // offset
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
|
||||
@@ -4,6 +4,7 @@ from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
@@ -30,8 +31,10 @@ from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
encode_prompt,
|
||||
get_memory_pressure_threshold,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
measure_kv_cache_bytes_per_token,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -43,6 +46,7 @@ from exo.worker.engines.mlx.constants import (
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_any,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -148,7 +152,8 @@ def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> int:
|
||||
) -> tuple[int, int]:
|
||||
"""Run warmup inference and tokens_generated and bytes_per_token"""
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
warmup_prompt = apply_chat_template(
|
||||
@@ -187,9 +192,12 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
|
||||
logger.info(f"Measured KV cache cost: {bytes_per_token} bytes per token")
|
||||
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated
|
||||
return tokens_generated, bytes_per_token
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
@@ -267,6 +275,37 @@ def extract_top_logprobs(
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def _check_memory_budget(
|
||||
bytes_per_token: int,
|
||||
total_sequence_tokens: int,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> str | None:
|
||||
if bytes_per_token == 0:
|
||||
return None
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
estimated = bytes_per_token * total_sequence_tokens / mem.total
|
||||
projected = mem.percent / 100 + estimated
|
||||
threshold = get_memory_pressure_threshold()
|
||||
|
||||
if not mx_any(projected > threshold, group):
|
||||
return None
|
||||
|
||||
if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
|
||||
mx.clear_cache()
|
||||
mem = psutil.virtual_memory()
|
||||
projected = mem.percent / 100 + estimated
|
||||
if not mx_any(projected > threshold, group):
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Not enough memory for this conversation ({projected:.0%} projected, "
|
||||
f"{threshold:.0%} limit). "
|
||||
f"Please start a new conversation or compact your messages."
|
||||
)
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -275,6 +314,7 @@ def mlx_generate(
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
bytes_per_token: int = 0,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -307,6 +347,23 @@ def mlx_generate(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
if bytes_per_token > 0:
|
||||
oom_error = _check_memory_budget(
|
||||
bytes_per_token=bytes_per_token,
|
||||
total_sequence_tokens=len(all_prompt_tokens),
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
if oom_error is not None:
|
||||
logger.warning(f"OOM prevention (prefill): {oom_error}")
|
||||
yield GenerationResponse(
|
||||
text=oom_error,
|
||||
token=0,
|
||||
finish_reason="error",
|
||||
usage=None,
|
||||
)
|
||||
return
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import cache
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -64,7 +65,7 @@ from exo.shared.types.worker.runners import (
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, get_memory_pressure_threshold
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
PrefillCancelled,
|
||||
mlx_generate,
|
||||
@@ -114,6 +115,7 @@ def main(
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
bytes_per_token: int = 0
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -225,12 +227,14 @@ def main(
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
toks = warmup_inference(
|
||||
toks, bytes_per_token = warmup_inference(
|
||||
model=cast(Model, inference_model),
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"warmed up by generating {toks} tokens, {bytes_per_token} bytes/token for KV cache"
|
||||
)
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
@@ -310,6 +314,7 @@ def main(
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
group=group,
|
||||
bytes_per_token=bytes_per_token,
|
||||
)
|
||||
|
||||
if tokenizer.has_thinking:
|
||||
@@ -336,6 +341,7 @@ def main(
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = check_for_cancel_every
|
||||
oom_stopped = False
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
@@ -344,7 +350,14 @@ def main(
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
oom_local = (
|
||||
bytes_per_token > 0
|
||||
and psutil.virtual_memory().percent / 100
|
||||
> get_memory_pressure_threshold()
|
||||
)
|
||||
if mx_any(want_to_cancel or oom_local, group):
|
||||
if not want_to_cancel:
|
||||
oom_stopped = True
|
||||
break
|
||||
|
||||
match response:
|
||||
@@ -400,6 +413,21 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
if oom_stopped and device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=model_id,
|
||||
error_message=(
|
||||
"Generation stopped: running out of memory. "
|
||||
"Please start a new conversation or compact "
|
||||
"your messages."
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except PrefillCancelled:
|
||||
logger.info(f"Prefill cancelled for task {task.task_id}")
|
||||
# can we make this more explicit?
|
||||
|
||||
@@ -114,7 +114,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin((1, 0)))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
|
||||
11
uv.lock
generated
11
uv.lock
generated
@@ -385,6 +385,7 @@ dependencies = [
|
||||
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "python-multipart", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -423,6 +424,7 @@ requires-dist = [
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.11.7" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.21" },
|
||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||
{ name = "tiktoken", specifier = ">=0.12.0" },
|
||||
{ name = "tomlkit", specifier = ">=0.14.0" },
|
||||
@@ -1882,6 +1884,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.21"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/78/96/804520d0850c7db98e5ccb70282e29208723f0964e88ffd9d0da2f52ea09/python_multipart-0.0.21.tar.gz", hash = "sha256:7137ebd4d3bbf70ea1622998f902b97a29434a9e8dc40eb203bbcf7c2a2cba92", size = 37196, upload-time = "2025-12-17T09:24:22.446Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/76/03af049af4dcee5d27442f71b6924f01f3efb5d2bd34f23fcd563f2cc5f5/python_multipart-0.0.21-py3-none-any.whl", hash = "sha256:cf7a6713e01c87aa35387f4774e812c4361150938d20d232800f75ffcf266090", size = 24541, upload-time = "2025-12-17T09:24:21.153Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.3"
|
||||
|
||||
Reference in New Issue
Block a user