Compare commits

...

7 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
9c320d7757 Test LRU eviction 2026-01-23 20:43:51 +00:00
Ryuichi Leo Takashige
424d96c6ac Remove incorrect typing 2026-01-23 20:36:50 +00:00
Ryuichi Leo Takashige
2d42af8477 Add tests 2026-01-23 19:50:36 +00:00
Ryuichi Leo Takashige
a02b452e24 Try and limit memory consumption 2026-01-23 19:50:30 +00:00
Ryuichi Leo Takashige
7744420341 cleanup 2026-01-23 16:32:58 +00:00
Ryuichi Leo Takashige
b777c6f505 Merge remote-tracking branch 'origin/main' into fix-kv-prefix-cache
# Conflicts:
#	.mlx_typings/mlx_lm/tokenizer_utils.pyi
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/runner/runner.py
2026-01-23 16:11:26 +00:00
David Hind
812a9f232e Fix KV prefix cache for prompt reuse
- Wire up KVPrefixCache to runner and generate
- Fix exact match to return deepcopy (was returning reference)
- Fix trim_prompt_cache argument (was using wrong calculation)
- Fix token slicing to use best_snapshot_length (not index)
- Add _cache_length() using .offset for compatibility with older mlx_lm
- Fix prefill() to use max_tokens=1 with trim (workaround for mlx_lm bug)
- Add clear() method for single-cache behavior
- Remove KEEP_KV_SIZE limit from prefix matching
- Add minimal logging for cache hits/misses

Fix type errors and KV cache implementation

Type fixes for CI:
- Add KVCacheType alias matching make_kv_cache return type
- Update function signatures to use consistent cache types
- Add explicit type annotations

KV cache fixes to actually reduce TTFT:
- get_kv_cache now prefills internally and returns only last token
- stream_generate receives 1 token on cache hit instead of full prompt
- Extract encode_prompt as standalone function for reuse

Refactor KV cache: move prefill to generate.py, add shared KVCacheType

Address PR feedback:
- Move KVCacheType to shared/types/mlx.py for reuse across codebase
- Move prefill logic from cache.py to generate.py
- get_kv_cache now only returns cache + remaining tokens (no prefill)
- Caller (mlx_generate) is responsible for prefilling

Fix types: regenerate mlx stubs, remove type ignores

- Regenerate cache.pyi and tokenizer_utils.pyi stubs for latest mlx_lm
- Remove # type: ignore from cache.py (now fully typed)
- Remove unnecessary type ignores from generate.py
- Use mx.equal() instead of == for proper array typing

Fix encode_prompt to not add special tokens for chat-templated prompts

Chat templates (like Kimi-K2's <|im_user|>, <|im_middle|>, etc.) already
include their own structure markers. Adding BOS/EOS tokens on top of this
corrupts the prompt structure and can slow down prefill.

Use add_special_tokens=False since the chat template defines its own structure.

Add prefill logging with progress callbacks and timing stats
2026-01-23 15:38:28 +00:00
6 changed files with 794 additions and 75 deletions
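For orientation before the file diffs: a condensed sketch of how the new pieces are wired together. This is not code from the PR; it simply mirrors the call pattern that appears in generate.py further down (the path where a prefix cache is present), and the helper name generate_with_prefix_cache is made up for illustration.

from mlx_lm.generate import stream_generate

from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import prefill


def generate_with_prefix_cache(model, tokenizer, sampler, prompt, kv_prefix_cache, max_tokens):
    # Match the prompt against stored entries; get back a deep-copied, trimmed cache
    # plus whatever suffix of the prompt still needs prefilling.
    caches, remaining, matched_index = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
    # The caller is responsible for prefilling everything except the last token.
    prefill(model, tokenizer, sampler, remaining[:-1], caches)
    # stream_generate then starts from that last token with the warm cache.
    parts = []
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=remaining[-1:],
        max_tokens=max_tokens,
        sampler=sampler,
        prompt_cache=caches,
    ):
        parts.append(out.text)
    # Store prompt + completion so the next request can hit the prefix.
    kv_prefix_cache.add_kv_cache(tokenizer, prompt + "".join(parts), caches)
    return "".join(parts)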

View File (src/exo/shared/types/mlx.py)

@@ -0,0 +1,11 @@
"""Shared types for MLX-related functionality."""
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]

View File (src/exo/worker/engines/mlx/cache.py)

@@ -1,39 +1,74 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable
from typing import Any, cast
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_MEMORY_PRESSURE_THRESHOLD = 0.85
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._last_used.clear()
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
def update_kv_cache(
self,
index: int,
tokenizer: TokenizerWrapper,
prompt: str,
cache: KVCacheType,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(tokenizer, prompt)
self.prompts[index] = tokenized_prompt
self.caches[index] = deepcopy(cache)
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
Returns:
Tuple of (cache, remaining_tokens, matched_index) where:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
"""
tokenized_prompt = encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +77,102 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
return self.caches[i]
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
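# Worked example (hypothetical numbers, not part of the diff): cached_length = 120 and a
# prompt of max_length = 100 -> trim 120 - (100 - 1) = 21 entries, leaving 99 cached tokens;
# the 100th token is handed to stream_generate as its one-token prompt.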
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
return prompt_cache
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_PRESSURE_THRESHOLD:
return
# Evict LRU entries until memory drops below the threshold (this may empty the cache entirely)
while len(self.caches) > 0:
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_PRESSURE_THRESHOLD:
break
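# Example with made-up numbers: if max_recommended_working_set_size is 64 GiB, eviction
# starts once active memory exceeds 0.85 * 64 GiB ~= 54.4 GiB, dropping the entry with the
# smallest _last_used counter on each pass until usage falls back under that line.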
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
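# E.g. (illustrative number): after caching a 512-token prompt each layer's cache reports
# offset == 512, so the max over layers is 512.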
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
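# Worked example (illustrative values, not part of the diff):
#   prompt  = [1, 2, 3, 4, 5], cached_prompt = [1, 2, 9, 4, 5]
#   equal   -> [1, 1, 0, 1, 1]
#   cumprod -> [1, 1, 0, 0, 0]
#   sum     -> 2, the length of the common prefix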
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass

View File (src/exo/worker/engines/mlx/constants.py)

@@ -4,7 +4,7 @@
KV_GROUP_SIZE: int | None = 32
KV_BITS: int | None = None
ATTENTION_KV_BITS: int | None = 4
MAX_TOKENS: int = 8192
MAX_TOKENS: int = 32168
MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"

View File (src/exo/worker/engines/mlx/generator/generate.py)

@@ -1,12 +1,12 @@
import time
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,11 +14,13 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -29,20 +31,62 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
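# At this point the cache holds the prompt plus the single throwaway decode step, so
# trimming one entry leaves exactly len(prompt_tokens) cached tokens (the prefill test
# further down asserts this).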
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
def warmup_inference(
@@ -120,6 +164,7 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -131,7 +176,22 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
caches = make_kv_cache(model=model)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
prefix_hit_length = 0
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -144,11 +204,19 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -158,12 +226,13 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -185,6 +254,28 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index, tokenizer, full_prompt, caches
)
else:
kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
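# With _MIN_PREFIX_HIT_TO_UPDATE = 1000: a hit that reused at least 1000 cached tokens
# overwrites the matched entry in place (the new prompt + completion extends it), while
# shorter hits and misses append a fresh entry instead.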
break
# TODO: Do we want an mx_barrier?

View File (src/exo/worker/runner/runner.py)

@@ -70,6 +70,7 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -103,6 +104,7 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -171,6 +173,9 @@ def main(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
kv_prefix_cache = KVPrefixCache()
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -238,6 +243,7 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# GPT-OSS specific parsing to match other model formats.

View File

@@ -0,0 +1,537 @@
# type: ignore
import time
from typing import cast
from unittest.mock import patch
import mlx.core as mx
import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
encode_prompt,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, make_kv_cache
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
DEFAULT_GPT_OSS_MODEL_ID,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
class TestKVPrefix:
def test_starts_empty(self):
cache = KVPrefixCache()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self):
cache = KVPrefixCache()
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self):
cache = KVPrefixCache()
cache.clear()
assert len(cache.prompts) == 0
def _load_gpt_oss() -> tuple[Model, object]:
from mlx_lm.utils import load_model
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
model, _ = load_model(model_path, lazy=False)
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
return cast(Model, model), tokenizer
@pytest.mark.slow
@pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
)
class TestKVPrefixCacheWithModel:
@pytest.fixture(scope="class")
def model_and_tokenizer(self):
model, tokenizer = _load_gpt_oss()
return model, tokenizer
def test_prefill_populates_cache(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Test exact")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
model, tokenizer = model_and_tokenizer
short_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hi")],
max_tokens=1,
)
short_prompt = apply_chat_template(tokenizer, short_task)
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, short_prompt, cache)
# Query with longer prompt that shares the chat template prefix
long_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content="Hi there, how are you?")
],
max_tokens=1,
)
long_prompt = apply_chat_template(tokenizer, long_task)
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, long_prompt
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
):
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra_keys, extra_values)
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
):
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra = mx.random.normal((1, num_heads, 30, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# Consume the entire generator so the cache-saving code after yield runs
generated_tokens = 0
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
generated_tokens += 1
assert len(kv_prefix_cache.prompts) == 1
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
assert len(kv_prefix_cache.prompts) == 1
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
repeats = (1200 // len(base_tokens)) + 2
long_content = base_text * repeats
task1 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=long_content)],
max_tokens=5,
)
prompt1 = apply_chat_template(tokenizer, task1)
prompt1_tokens = encode_prompt(tokenizer, prompt1)
assert len(prompt1_tokens) > 1000, (
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
)
# First generation populates the cache (must prefill all tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content=long_content),
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
ChatCompletionMessage(role="user", content="Tell me more."),
],
max_tokens=5,
)
prompt2 = apply_chat_template(tokenizer, task2)
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
):
pass
second_gen_time = time.perf_counter() - t0
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
assert second_gen_time < first_gen_time * 0.5, (
f"Expected prefix cache speedup: "
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
)
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
for i, content in enumerate(prompts):
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=content)],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
assert len(kv_prefix_cache.prompts) == 3
# Access the third entry to make it most recently used
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_PRESSURE_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
):
# Trigger eviction by adding a new entry
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="New entry")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since fake_active stays above threshold after each eviction (we don't change it),
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)