Compare commits

...

1 Commit

Author SHA1 Message Date
Alex Cheema
d91ab9456e feat: add KV prefix caching for prompt reuse
Add LRU-based KV prefix cache to reuse computed prompt prefixes across
requests. When multiple requests share a common prefix (e.g., system prompt),
the cached KV state is reused instead of recomputing it.

Changes:
- Add KVPrefixCache class with LRU eviction in cache.py
- Integrate prefix cache into mlx_generate in generate.py
- Create cache instance in runner.py
- Add comprehensive tests in test_prefix_cache.py
- Update AGENTS.md with full type check command

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 19:02:26 +00:00
5 changed files with 330 additions and 76 deletions
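The call flow introduced by this commit, sketched from the diffs that follow: the runner creates one `KVPrefixCache`, `mlx_generate` asks it for a KV cache before streaming, and stores the resulting cache back once generation finishes. The sketch below is illustrative only and assumes `model`, `tokenizer`, `sampler`, and `prompt` are already set up by the runner; it is not code from the commit itself.

```python
# Illustrative sketch of the prefix-cache call flow added in this commit.
# Assumes `model`, `tokenizer`, `sampler`, and `prompt` are already constructed.
from exo.worker.engines.mlx.cache import KVPrefixCache

prefix_cache = KVPrefixCache(max_size=10)  # one cache per runner, LRU with 10 entries

# Before generation: reuse any cached KV state whose tokens prefix-match the prompt.
caches, tokens_reused = prefix_cache.get_kv_cache(
    model=model, tokenizer=tokenizer, sampler=sampler, prompt=prompt
)
# tokens_reused == 0 means a miss; stream_generate will prefill the full prompt.

# ... run stream_generate(..., prompt_cache=caches) ...

# After generation: store a deep copy so future prompts sharing this prefix can reuse it.
prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
```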

AGENTS.md

@@ -30,14 +30,17 @@ uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Type checking (strict mode) - MUST pass before committing
uv run basedpyright --project pyproject.toml
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
# Run all checks (do this before committing)
uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
```
## Architecture
@@ -91,6 +94,10 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## File Locations
- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
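A minimal sketch of what that convention looks like in practice (the file name and assertion below are illustrative, not part of this commit): with `asyncio_mode = "auto"`, plain `async def` test functions are collected without an explicit marker, and `EXO_TESTS=1` is visible from inside the test.

```python
# tests/test_example.py  (illustrative name; tests live in tests/ next to the code they cover)
import os


async def test_runs_with_auto_asyncio_mode() -> None:
    # asyncio_mode = "auto" means no @pytest.mark.asyncio decorator is needed.
    assert os.environ.get("EXO_TESTS") == "1"
```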

cache.py

@@ -1,104 +1,189 @@
# type: ignore
# TODO: Fix this file, including types!
"""KV prefix cache for reusing computed prompt prefixes across requests."""
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Protocol
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
import numpy as np
from mlx_lm.models.cache import trim_prompt_cache
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
# Type alias for KV cache - the actual type is _BaseCache but it's private
KVCacheType = Any
class TokenizerProtocol(Protocol):
"""Protocol for tokenizers used with KVPrefixCache."""
bos_token: str | None
def encode(self, text: str, **kwargs: bool) -> list[int]: ...
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
"""Cache for common prompt prefixes to avoid re-processing.
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
Uses LRU eviction when capacity is reached. Stores tokenized prompts
and their corresponding KV caches for reuse.
"""
def __init__(self, max_size: int = 10):
"""Initialize prefix cache.
Args:
max_size: Maximum number of cached entries before LRU eviction.
"""
self.max_size = max_size
# OrderedDict maintains insertion order for LRU - most recent at end
# Key: token bytes, Value: (tokens as mx.array, KV cache)
self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
OrderedDict()
)
def _token_key(self, tokens: mx.array) -> bytes:
"""Create hashable key from token array."""
return np.array(tokens.tolist(), dtype=np.int32).tobytes()
def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
"""Tokenize prompt string to mx.array."""
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
return mx.array(tokenized)
def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
"""Find cached entry with longest matching prefix.
Returns:
Tuple of (cache_key, prefix_length). cache_key is None if no match found.
"""
best_key: bytes | None = None
best_length = 0
target_len = tokens.shape[0]
for key, (cached_tokens, _cache) in self._cache.items():
prefix_len = get_prefix_length(tokens, cached_tokens)
# Exact match - return immediately
if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
return key, prefix_len
# Better prefix match
if prefix_len > best_length:
best_key = key
best_length = prefix_len
return best_key, best_length
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
tokenizer: TokenizerProtocol,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
) -> tuple[list[KVCacheType], int]:
"""Get KV cache for prompt, reusing prefix if available.
best_snapshot_index, best_snapshot_length = None, 0
Args:
model: The model to create cache for.
tokenizer: Tokenizer for encoding prompt.
sampler: Sampler function for prefill.
prompt: The prompt string to process.
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
Returns:
Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
tokens were reused from cache (0 if no cache hit).
"""
tokens = self._encode_prompt(tokenizer, prompt)
target_len = int(tokens.shape[0])
if length == max_length:
return self.caches[i]
# Find best prefix match
best_key, prefix_len = self._find_best_prefix(tokens)
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_key is not None and prefix_len > 0:
cached_tokens, cached_kv = self._cache[best_key]
cached_len = int(cached_tokens.shape[0])
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
# Move to end (most recently used)
self._cache.move_to_end(best_key)
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
if prefix_len == target_len and prefix_len == cached_len:
# Exact match - return deepcopy directly
logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
return deepcopy(cached_kv), prefix_len
# Partial match - need to trim and/or extend
prompt_cache = deepcopy(cached_kv)
if cached_len > prefix_len:
# Cached prompt is longer - trim to prefix length
num_to_trim = cached_len - prefix_len
trim_prompt_cache(prompt_cache, num_to_trim)
logger.debug(
f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
)
# Note: We don't prefill remaining tokens here - stream_generate will do it
# when processing the full prompt with this partial cache
return prompt_cache, prefix_len
# No cache hit - return fresh cache (stream_generate will prefill)
logger.debug(
f"Prefix cache: miss, will prefill {target_len} tokens during generation"
)
prompt_cache = make_kv_cache(model=model)
return prompt_cache, 0
def put(
self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
) -> None:
"""Store KV cache for prompt after generation completes.
Args:
tokenizer: Tokenizer for encoding prompt.
prompt: The prompt string that was processed.
cache: The KV cache to store.
"""
tokens = self._encode_prompt(tokenizer, prompt)
key = self._token_key(tokens)
# If already in cache, just move to end
if key in self._cache:
self._cache.move_to_end(key)
return
# Evict LRU entry if at capacity
if len(self._cache) >= self.max_size:
evicted_key, _ = self._cache.popitem(last=False)
logger.debug(
f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
# Store deepcopy
self._cache[key] = (tokens, deepcopy(cache))
logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")
return prompt_cache
def clear(self) -> None:
"""Clear all cached entries."""
self._cache.clear()
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def __len__(self) -> int:
"""Return number of cached entries."""
return len(self._cache)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Calculate length of matching prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
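The prefix match in `get_prefix_length` relies on `cumprod` staying at 1 until the first mismatch and 0 afterwards. A small worked example of that behavior (illustrative, not part of the diff):

```python
# Worked example of the cumprod prefix match used by get_prefix_length (illustrative).
import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 9, 9])

n = min(int(a.shape[0]), int(b.shape[0]))
equal = mx.equal(a[:n], b[:n]).astype(mx.int32)  # [1, 1, 1, 0, 0]
prefix_mask = mx.cumprod(equal)                  # [1, 1, 1, 0, 0] -> zeros after first mismatch
print(int(mx.sum(prefix_mask).item()))           # 3 shared prefix tokens
```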

generate.py

@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -119,6 +119,7 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -135,8 +136,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -148,6 +147,20 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Get KV cache - either from prefix cache or fresh
tokens_reused = 0
if prefix_cache is not None:
caches, tokens_reused = prefix_cache.get_kv_cache(
model=model,
tokenizer=tokenizer,
sampler=sampler,
prompt=prompt,
)
if tokens_reused > 0:
logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
else:
caches = make_kv_cache(model=model)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -189,6 +202,9 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Store in prefix cache for future reuse
if prefix_cache is not None:
prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
break
# TODO: Do we want an mx_barrier?

runner.py

@@ -39,6 +39,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -69,6 +70,7 @@ def main(
model = None
tokenizer = None
group = None
prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -110,6 +112,8 @@ def main(
)
model, tokenizer = load_mlx_items(bound_instance, group)
prefix_cache = KVPrefixCache(max_size=10)
logger.info("prefix cache initialized")
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -157,6 +161,7 @@ def main(
model=model,
tokenizer=tokenizer,
task=task_params,
prefix_cache=prefix_cache,
):
match response:
case GenerationResponse():

test_prefix_cache.py

@@ -0,0 +1,141 @@
"""Tests for KVPrefixCache."""
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.cache import (
KVCacheType,
KVPrefixCache,
TokenizerProtocol,
get_prefix_length,
)
class MockTokenizer(TokenizerProtocol):
"""Mock tokenizer that converts string to list of char codes."""
bos_token: str | None = None
def encode(self, text: str, **kwargs: bool) -> list[int]:
"""Encode text to list of character codes."""
del kwargs # unused
return [ord(c) for c in text]
class TestGetPrefixLength:
"""Tests for the core prefix matching algorithm."""
def test_identical_arrays(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
def test_partial_match(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 9, 9])
assert get_prefix_length(a, b) == 3
def test_no_match(self) -> None:
a = mx.array([1, 2, 3])
b = mx.array([9, 9, 9])
assert get_prefix_length(a, b) == 0
def test_different_lengths(self) -> None:
short = mx.array([1, 2, 3])
long = mx.array([1, 2, 3, 4, 5])
# Should return length of shorter when they match
assert get_prefix_length(short, long) == 3
assert get_prefix_length(long, short) == 3
def test_empty_array(self) -> None:
empty: mx.array = mx.array([])
tokens = mx.array([1, 2, 3])
assert get_prefix_length(empty, tokens) == 0
assert get_prefix_length(tokens, empty) == 0
class TestKVPrefixCache:
"""Tests for the KV prefix cache."""
@pytest.fixture
def tokenizer(self) -> MockTokenizer:
"""Mock tokenizer that converts string to list of char codes."""
return MockTokenizer()
@pytest.fixture
def fake_kv(self) -> list[KVCacheType]:
"""Fake KV cache for testing."""
return [object()]
def test_put_stores_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_put_same_prompt_twice_does_not_duplicate(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_lru_eviction(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Fill cache
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
assert len(cache) == 2
# Add third - should evict "first" (oldest)
cache.put(tokenizer, "third", fake_kv)
assert len(cache) == 2
# Add "first" again - if it was evicted, cache size stays 2
# If it wasn't evicted, this would be a no-op
cache.put(tokenizer, "first", fake_kv)
# Now add fourth - if "first" was re-added, size is still 2
cache.put(tokenizer, "fourth", fake_kv)
assert len(cache) == 2
def test_lru_access_refreshes_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Add two entries
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
# Access "first" again (moves to end of LRU)
cache.put(tokenizer, "first", fake_kv)
# Add third - should evict "second" now (oldest)
cache.put(tokenizer, "third", fake_kv)
# Add "second" again - this will add it as new entry
cache.put(tokenizer, "second", fake_kv)
# Now "first" is oldest, adding fourth should evict it
cache.put(tokenizer, "fourth", fake_kv)
# Cache should have "second" and "fourth", not "first"
assert len(cache) == 2
def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
cache = KVPrefixCache()
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "world", fake_kv)
assert len(cache) == 2
cache.clear()
assert len(cache) == 0