Compare commits

..

1 Commit

Author SHA1 Message Date
Evan
e8e5d3710f fix 2026-02-25 18:52:43 +00:00
8 changed files with 60 additions and 171 deletions

View File

@@ -249,8 +249,7 @@ class ChunkedKVCache(KVCache):
...
class CacheList(_BaseCache):
caches: tuple[_BaseCache, ...]
def __init__(self, *caches: _BaseCache) -> None: ...
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...

View File

@@ -524,15 +524,15 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
sharding=instance.sharding(),
instance_meta=instance.instance_meta(),
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -541,8 +541,8 @@ class API:
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
)
)

View File

@@ -4,7 +4,13 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)
@staticmethod
def instance_meta() -> InstanceMeta: ...
def sharding(self) -> Sharding:
if all(
isinstance(sm, PipelineShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Pipeline
if all(
isinstance(sm, TensorShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Tensor
raise ValueError("shard metadata malformed")
class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxRing
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxJaccl
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance

View File

@@ -542,10 +542,13 @@ class InfoGatherer:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
t = TextReceiveStream(BufferedByteReceiveStream(p.stdout))
while True:
with anyio.fail_after(self.macmon_interval * 3):
macmon_output = await t.receive()
await self.info_sender.send(
MacmonMetrics.from_raw_json(macmon_output)
)
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
@@ -556,8 +559,12 @@ class InfoGatherer:
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
f"memory monitor failed with return code {e.returncode}: {stderr_msg}"
)
except TimeoutError:
logger.warning(
f"memory monitor silent for {self.macmon_interval * 3}s - reloading"
)
except Exception as e:
logger.warning(f"Error in macmon monitor: {e}")
logger.opt(exception=e).warning("Error in memory monitor")
await anyio.sleep(self.macmon_interval)

View File

@@ -32,7 +32,7 @@ def _default_memory_threshold() -> float:
return 0.70
MEMORY_THRESHOLD = float(
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
)
@@ -92,15 +92,6 @@ class KVPrefixCache:
self._snapshots.clear()
self._last_used.clear()
def force_evict_all(self) -> int:
count = len(self.caches)
self.clear()
if count > 0:
logger.info(
f"Force-evicted all {count} prefix cache entries due to memory pressure"
)
return count
def add_kv_cache(
self,
prompt_tokens: mx.array,
@@ -226,7 +217,7 @@ class KVPrefixCache:
# Evict LRU entries until below threshold
while (
len(self.caches) > 0
and self.get_memory_used_percentage() > MEMORY_THRESHOLD
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
@@ -319,59 +310,6 @@ def get_memory_used_percentage() -> float:
return float(mem.percent / 100)
def get_safety_floor() -> int:
total = psutil.virtual_memory().total
return min(int(total * 0.10), 5 * 1024**3)
def get_memory_pressure_threshold() -> float:
total = psutil.virtual_memory().total
return 1.0 - get_safety_floor() / total
def _measure_single_cache_bytes(
entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
) -> int:
if isinstance(entry, CacheList):
return sum(
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
for c in entry.caches
)
total = 0
if isinstance(entry, ArraysCache):
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
for arr in state: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
total = 0
for attr_name in ("keys", "values"):
val: object = getattr(entry, attr_name, None)
if val is None:
continue
if isinstance(val, mx.array):
total += val.nbytes
elif isinstance(val, (tuple, list)):
for arr in val: # pyright: ignore[reportUnknownVariableType]
if isinstance(arr, mx.array):
total += arr.nbytes
return total
def measure_cache_bytes(cache: KVCacheType) -> int:
return sum(_measure_single_cache_bytes(c) for c in cache)
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> int:
offset = cache_length(cache)
if offset == 0:
return 0
return measure_cache_bytes(cache) // offset
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:

View File

@@ -4,7 +4,6 @@ from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
import psutil
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
@@ -31,10 +30,8 @@ from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
encode_prompt,
get_memory_pressure_threshold,
has_non_kv_caches,
make_kv_cache,
measure_kv_cache_bytes_per_token,
snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
@@ -46,7 +43,6 @@ from exo.worker.engines.mlx.constants import (
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
fix_unmatched_think_end_tokens,
mx_any,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -152,8 +148,7 @@ def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None,
) -> tuple[int, int]:
"""Run warmup inference and tokens_generated and bytes_per_token"""
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
warmup_prompt = apply_chat_template(
@@ -192,12 +187,9 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
logger.info(f"Measured KV cache cost: {bytes_per_token} bytes per token")
mx_barrier(group)
return tokens_generated, bytes_per_token
return tokens_generated
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
@@ -275,37 +267,6 @@ def extract_top_logprobs(
return selected_logprob, top_logprob_items
def _check_memory_budget(
bytes_per_token: int,
total_sequence_tokens: int,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
) -> str | None:
if bytes_per_token == 0:
return None
mem = psutil.virtual_memory()
estimated = bytes_per_token * total_sequence_tokens / mem.total
projected = mem.percent / 100 + estimated
threshold = get_memory_pressure_threshold()
if not mx_any(projected > threshold, group):
return None
if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
mx.clear_cache()
mem = psutil.virtual_memory()
projected = mem.percent / 100 + estimated
if not mx_any(projected > threshold, group):
return None
return (
f"Not enough memory for this conversation ({projected:.0%} projected, "
f"{threshold:.0%} limit). "
f"Please start a new conversation or compact your messages."
)
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -314,7 +275,6 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
bytes_per_token: int = 0,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -347,23 +307,6 @@ def mlx_generate(
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
)
if bytes_per_token > 0:
oom_error = _check_memory_budget(
bytes_per_token=bytes_per_token,
total_sequence_tokens=len(all_prompt_tokens),
kv_prefix_cache=kv_prefix_cache,
group=group,
)
if oom_error is not None:
logger.warning(f"OOM prevention (prefill): {oom_error}")
yield GenerationResponse(
text=oom_error,
token=0,
finish_reason="error",
usage=None,
)
return
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens

View File

@@ -6,7 +6,6 @@ from functools import cache
from typing import TYPE_CHECKING, cast
import mlx.core as mx
import psutil
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -65,7 +64,7 @@ from exo.shared.types.worker.runners import (
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, get_memory_pressure_threshold
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
mlx_generate,
@@ -115,7 +114,6 @@ def main(
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
bytes_per_token: int = 0
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -227,14 +225,12 @@ def main(
assert tokenizer
t = time.monotonic()
toks, bytes_per_token = warmup_inference(
toks = warmup_inference(
model=cast(Model, inference_model),
tokenizer=tokenizer,
group=group,
)
logger.info(
f"warmed up by generating {toks} tokens, {bytes_per_token} bytes/token for KV cache"
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
@@ -314,7 +310,6 @@ def main(
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
group=group,
bytes_per_token=bytes_per_token,
)
if tokenizer.has_thinking:
@@ -341,7 +336,6 @@ def main(
completion_tokens = 0
tokens_since_last_cancel_check = check_for_cancel_every
oom_stopped = False
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -350,14 +344,7 @@ def main(
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
oom_local = (
bytes_per_token > 0
and psutil.virtual_memory().percent / 100
> get_memory_pressure_threshold()
)
if mx_any(want_to_cancel or oom_local, group):
if not want_to_cancel:
oom_stopped = True
if mx_any(want_to_cancel, group):
break
match response:
@@ -413,21 +400,6 @@ def main(
)
)
if oom_stopped and device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=model_id,
error_message=(
"Generation stopped: running out of memory. "
"Please start a new conversation or compact "
"your messages."
),
),
)
)
except PrefillCancelled:
logger.info(f"Prefill cancelled for task {task.task_id}")
# can we make this more explicit?

View File

@@ -114,7 +114,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin((1, 0)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).