Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-16 01:51:03 -05:00)

Compare commits: 1 commit on main

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | d91ab9456e | |

AGENTS.md (11 changed lines)
@@ -30,14 +30,17 @@ uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name

# Type checking (strict mode)
uv run basedpyright
# Type checking (strict mode) - MUST pass before committing
uv run basedpyright --project pyproject.toml

# Linting
uv run ruff check

# Format code (using nix)
nix fmt

# Run all checks (do this before committing)
uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
```

## Architecture

@@ -91,6 +94,10 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable

## File Locations

- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)

## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
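As a quick illustration of the convention described above, here is a minimal sketch of an async test under `asyncio_mode = "auto"` (the module path and test body are hypothetical, not taken from the repo):

```python
# Hypothetical example, e.g. src/exo/shared/tests/test_example.py.
# With asyncio_mode = "auto", async test functions need no @pytest.mark.asyncio marker.
import asyncio
import os


async def test_event_loop_is_provided() -> None:
    await asyncio.sleep(0)  # pytest-asyncio drives the event loop automatically
    assert os.environ.get("EXO_TESTS") == "1"  # set by the test harness per AGENTS.md
```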
@@ -1,104 +1,189 @@
# type: ignore
# TODO: Fix this file, including types!
"""KV prefix cache for reusing computed prompt prefixes across requests."""

from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Protocol

import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
import numpy as np
from mlx_lm.models.cache import trim_prompt_cache

from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger

# Type alias for KV cache - the actual type is _BaseCache but it's private
KVCacheType = Any


class TokenizerProtocol(Protocol):
    """Protocol for tokenizers used with KVPrefixCache."""

    bos_token: str | None

    def encode(self, text: str, **kwargs: bool) -> list[int]: ...


class KVPrefixCache:
    def __init__(self):
        # Only one prefix cache per runner.
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[list[_BaseCache]] = []
    """Cache for common prompt prefixes to avoid re-processing.

    def add_kv_cache(
        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
    ):
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        self.prompts.append(tokenized_prompt)
        self.caches.append(deepcopy(cache))
    Uses LRU eviction when capacity is reached. Stores tokenized prompts
    and their corresponding KV caches for reuse.
    """

    def __init__(self, max_size: int = 10):
        """Initialize prefix cache.

        Args:
            max_size: Maximum number of cached entries before LRU eviction.
        """
        self.max_size = max_size
        # OrderedDict maintains insertion order for LRU - most recent at end
        # Key: token bytes, Value: (tokens as mx.array, KV cache)
        self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
            OrderedDict()
        )

    def _token_key(self, tokens: mx.array) -> bytes:
        """Create hashable key from token array."""
        return np.array(tokens.tolist(), dtype=np.int32).tobytes()

    def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
        """Tokenize prompt string to mx.array."""
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        return mx.array(tokenized)

    def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
        """Find cached entry with longest matching prefix.

        Returns:
            Tuple of (cache_key, prefix_length). cache_key is None if no match found.
        """
        best_key: bytes | None = None
        best_length = 0
        target_len = tokens.shape[0]

        for key, (cached_tokens, _cache) in self._cache.items():
            prefix_len = get_prefix_length(tokens, cached_tokens)

            # Exact match - return immediately
            if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
                return key, prefix_len

            # Better prefix match
            if prefix_len > best_length:
                best_key = key
                best_length = prefix_len

        return best_key, best_length

    def get_kv_cache(
        self,
        model: Model,
        tokenizer: TokenizerWrapper,
        tokenizer: TokenizerProtocol,
        sampler: Callable[[mx.array], mx.array],
        prompt: str,
    ) -> list[_BaseCache]:
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        max_length = len(tokenized_prompt)
    ) -> tuple[list[KVCacheType], int]:
        """Get KV cache for prompt, reusing prefix if available.

        best_snapshot_index, best_snapshot_length = None, 0
        Args:
            model: The model to create cache for.
            tokenizer: Tokenizer for encoding prompt.
            sampler: Sampler function for prefill.
            prompt: The prompt string to process.

        for i, cached_prompt in enumerate(self.prompts):
            length = _get_prefix_length(tokenized_prompt, cached_prompt)
        Returns:
            Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
            tokens were reused from cache (0 if no cache hit).
        """
        tokens = self._encode_prompt(tokenizer, prompt)
        target_len = int(tokens.shape[0])

            if length == max_length:
                return self.caches[i]
        # Find best prefix match
        best_key, prefix_len = self._find_best_prefix(tokens)

            if length > best_snapshot_length:
                best_snapshot_index, best_snapshot_length = i, length
        if best_key is not None and prefix_len > 0:
            cached_tokens, cached_kv = self._cache[best_key]
            cached_len = int(cached_tokens.shape[0])

        if best_snapshot_index is not None:
            prompt_cache = deepcopy(self.caches[best_snapshot_index])
            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
            tokenized_prompt = tokenized_prompt[best_snapshot_index:]
            # Move to end (most recently used)
            self._cache.move_to_end(best_key)

        else:
            prompt_cache = make_kv_cache(
                model,
                # max_kv_size=MAX_KV_SIZE,
                # keep=KEEP_KV_SIZE
            if prefix_len == target_len and prefix_len == cached_len:
                # Exact match - return deepcopy directly
                logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
                return deepcopy(cached_kv), prefix_len

            # Partial match - need to trim and/or extend
            prompt_cache = deepcopy(cached_kv)

            if cached_len > prefix_len:
                # Cached prompt is longer - trim to prefix length
                num_to_trim = cached_len - prefix_len
                trim_prompt_cache(prompt_cache, num_to_trim)
                logger.debug(
                    f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
                )

            # Note: We don't prefill remaining tokens here - stream_generate will do it
            # when processing the full prompt with this partial cache
            return prompt_cache, prefix_len

        # No cache hit - return fresh cache (stream_generate will prefill)
        logger.debug(
            f"Prefix cache: miss, will prefill {target_len} tokens during generation"
        )
        prompt_cache = make_kv_cache(model=model)

        return prompt_cache, 0

    def put(
        self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
    ) -> None:
        """Store KV cache for prompt after generation completes.

        Args:
            tokenizer: Tokenizer for encoding prompt.
            prompt: The prompt string that was processed.
            cache: The KV cache to store.
        """
        tokens = self._encode_prompt(tokenizer, prompt)
        key = self._token_key(tokens)

        # If already in cache, just move to end
        if key in self._cache:
            self._cache.move_to_end(key)
            return

        # Evict LRU entry if at capacity
        if len(self._cache) >= self.max_size:
            evicted_key, _ = self._cache.popitem(last=False)
            logger.debug(
                f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
            )

        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
        # Store deepcopy
        self._cache[key] = (tokens, deepcopy(cache))
        logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")

        return prompt_cache
    def clear(self) -> None:
        """Clear all cached entries."""
        self._cache.clear()

    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        tokenized_prompt = tokenizer.encode(
            prompt, add_special_tokens=add_special_tokens
        )
        return mx.array(tokenized_prompt)
    def __len__(self) -> int:
        """Return number of cached entries."""
        return len(self._cache)


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    """Calculate length of matching prefix between two token arrays."""
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
    if n == 0:
        return 0

    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
    return int(mx.sum(prefix_mask).item())


def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt: mx.array,
    cache: list[_BaseCache],
) -> None:
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=0,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        pass
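A short worked example of the cumulative-product trick used by `get_prefix_length` above (standalone sketch with made-up token values, not part of the commit):

```python
import mlx.core as mx

# Two tokenized prompts that share their first three tokens.
prompt = mx.array([1, 2, 3, 4, 5])
cached = mx.array([1, 2, 3, 9, 9])

equal = mx.equal(prompt, cached).astype(mx.int32)  # [1, 1, 1, 0, 0]
prefix_mask = mx.cumprod(equal)                    # [1, 1, 1, 0, 0]
print(int(mx.sum(prefix_mask).item()))             # 3

# cumprod (rather than a plain sum of `equal`) guarantees that a later
# coincidental match, e.g. [1, 2, 9, 4] vs [1, 2, 3, 4], still yields 2.
```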
@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
    BenchChatCompletionTaskParams,
    ChatCompletionMessage,
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -119,6 +119,7 @@ def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -135,8 +136,6 @@ def mlx_generate(
        chat_task_data=task,
    )

    caches = make_kv_cache(model=model)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
        # Only sample length eos tokens
@@ -148,6 +147,20 @@ def mlx_generate(
        top_p=task.top_p if task.top_p is not None else 1.0,
    )

    # Get KV cache - either from prefix cache or fresh
    tokens_reused = 0
    if prefix_cache is not None:
        caches, tokens_reused = prefix_cache.get_kv_cache(
            model=model,
            tokenizer=tokenizer,
            sampler=sampler,
            prompt=prompt,
        )
        if tokens_reused > 0:
            logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
    else:
        caches = make_kv_cache(model=model)

    max_tokens = task.max_tokens or MAX_TOKENS
    for out in stream_generate(
        model=model,
@@ -189,6 +202,9 @@ def mlx_generate(
        )

        if out.finish_reason is not None:
            # Store in prefix cache for future reuse
            if prefix_cache is not None:
                prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
            break

    # TODO: Do we want an mx_barrier?
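Condensed sketch of the control flow the hunk above wires into `mlx_generate` (names taken from the diff; the wrapper function itself is illustrative, not code from the commit, and it omits sampling and error handling):

```python
# Assumed imports, matching the modules referenced in the diff.
from mlx_lm import stream_generate

from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.constants import MAX_TOKENS


def generate_with_prefix_reuse(model, tokenizer, sampler, prompt: str,
                               prefix_cache: KVPrefixCache):
    # 1. Reuse the longest previously computed prefix of this prompt (0 tokens on a miss).
    caches, tokens_reused = prefix_cache.get_kv_cache(
        model=model, tokenizer=tokenizer, sampler=sampler, prompt=prompt
    )
    # 2. stream_generate prefills only the part of the prompt not already in `caches`.
    for out in stream_generate(
        model=model, tokenizer=tokenizer, prompt=prompt,
        max_tokens=MAX_TOKENS, prompt_cache=caches,
    ):
        yield out
        if out.finish_reason is not None:
            # 3. Store the now fully populated cache for future prompts with this prefix.
            prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
            break
```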
@@ -39,6 +39,7 @@ from exo.shared.types.worker.runners import (
    RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
@@ -69,6 +70,7 @@ def main(
    model = None
    tokenizer = None
    group = None
    prefix_cache: KVPrefixCache | None = None

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")
@@ -110,6 +112,8 @@ def main(
                )

                model, tokenizer = load_mlx_items(bound_instance, group)
                prefix_cache = KVPrefixCache(max_size=10)
                logger.info("prefix cache initialized")

                current_status = RunnerLoaded()
                logger.info("runner loaded")
@@ -157,6 +161,7 @@ def main(
                    model=model,
                    tokenizer=tokenizer,
                    task=task_params,
                    prefix_cache=prefix_cache,
                ):
                    match response:
                        case GenerationResponse():
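The new test file below exercises the eviction behaviour; for reference, a minimal standalone sketch of the `OrderedDict`-based LRU pattern that `KVPrefixCache.put` uses (generic string/int entries instead of real token arrays and KV caches):

```python
from collections import OrderedDict

lru: OrderedDict[str, int] = OrderedDict()
MAX_SIZE = 2


def put(key: str, value: int) -> None:
    if key in lru:
        lru.move_to_end(key)      # refresh: most recently used moves to the end
        return
    if len(lru) >= MAX_SIZE:
        lru.popitem(last=False)   # evict the least recently used (front) entry
    lru[key] = value


put("first", 1)
put("second", 2)
put("first", 1)     # refreshes "first"
put("third", 3)     # evicts "second", the oldest untouched entry
print(list(lru))    # ['first', 'third']
```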
src/exo/worker/tests/unittests/test_mlx/test_prefix_cache.py (new file, 141 lines)
@@ -0,0 +1,141 @@
"""Tests for KVPrefixCache."""
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVCacheType,
|
||||
KVPrefixCache,
|
||||
TokenizerProtocol,
|
||||
get_prefix_length,
|
||||
)
|
||||
|
||||
|
||||
class MockTokenizer(TokenizerProtocol):
|
||||
"""Mock tokenizer that converts string to list of char codes."""
|
||||
|
||||
bos_token: str | None = None
|
||||
|
||||
def encode(self, text: str, **kwargs: bool) -> list[int]:
|
||||
"""Encode text to list of character codes."""
|
||||
del kwargs # unused
|
||||
return [ord(c) for c in text]
|
||||
|
||||
|
||||
class TestGetPrefixLength:
|
||||
"""Tests for the core prefix matching algorithm."""
|
||||
|
||||
def test_identical_arrays(self) -> None:
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 5
|
||||
|
||||
def test_partial_match(self) -> None:
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 9, 9])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([9, 9, 9])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_different_lengths(self) -> None:
|
||||
short = mx.array([1, 2, 3])
|
||||
long = mx.array([1, 2, 3, 4, 5])
|
||||
# Should return length of shorter when they match
|
||||
assert get_prefix_length(short, long) == 3
|
||||
assert get_prefix_length(long, short) == 3
|
||||
|
||||
def test_empty_array(self) -> None:
|
||||
empty: mx.array = mx.array([])
|
||||
tokens = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(empty, tokens) == 0
|
||||
assert get_prefix_length(tokens, empty) == 0
|
||||
|
||||
|
||||
class TestKVPrefixCache:
|
||||
"""Tests for the KV prefix cache."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self) -> MockTokenizer:
|
||||
"""Mock tokenizer that converts string to list of char codes."""
|
||||
return MockTokenizer()
|
||||
|
||||
@pytest.fixture
|
||||
def fake_kv(self) -> list[KVCacheType]:
|
||||
"""Fake KV cache for testing."""
|
||||
return [object()]
|
||||
|
||||
def test_put_stores_entry(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=10)
|
||||
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
|
||||
assert len(cache) == 1
|
||||
|
||||
def test_put_same_prompt_twice_does_not_duplicate(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=10)
|
||||
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
|
||||
assert len(cache) == 1
|
||||
|
||||
def test_lru_eviction(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=2)
|
||||
|
||||
# Fill cache
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
# Add third - should evict "first" (oldest)
|
||||
cache.put(tokenizer, "third", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
# Add "first" again - if it was evicted, cache size stays 2
|
||||
# If it wasn't evicted, this would be a no-op
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
# Now add fourth - if "first" was re-added, size is still 2
|
||||
cache.put(tokenizer, "fourth", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
def test_lru_access_refreshes_entry(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=2)
|
||||
|
||||
# Add two entries
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
|
||||
# Access "first" again (moves to end of LRU)
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
|
||||
# Add third - should evict "second" now (oldest)
|
||||
cache.put(tokenizer, "third", fake_kv)
|
||||
|
||||
# Add "second" again - this will add it as new entry
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
# Now "first" is oldest, adding fourth should evict it
|
||||
cache.put(tokenizer, "fourth", fake_kv)
|
||||
|
||||
# Cache should have "second" and "fourth", not "first"
|
||||
assert len(cache) == 2
|
||||
|
||||
def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
|
||||
cache = KVPrefixCache()
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
cache.put(tokenizer, "world", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert len(cache) == 0
|
||||