Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-25 18:58:39 -05:00

Compare commits

1 commit · remove-pyt...

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | d91ab9456e | |

AGENTS.md (11 lines changed)
@@ -30,14 +30,17 @@ uv run pytest src/exo/shared/tests/test_election.py
 # Run a specific test function
 uv run pytest src/exo/shared/tests/test_election.py::test_function_name
 
-# Type checking (strict mode)
-uv run basedpyright
+# Type checking (strict mode) - MUST pass before committing
+uv run basedpyright --project pyproject.toml
 
 # Linting
 uv run ruff check
 
 # Format code (using nix)
 nix fmt
+
+# Run all checks (do this before committing)
+uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
 ```
 
 ## Architecture
@@ -91,6 +94,10 @@ From .cursorrules:
 - Catch exceptions only where you can handle them meaningfully
 - Use `@final` and immutability wherever applicable
 
+## File Locations
+
+- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)
+
 ## Testing
 
 Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
src/exo/worker/engines/mlx/cache.py

@@ -1,104 +1,189 @@
-# type: ignore
-# TODO: Fix this file, including types!
+"""KV prefix cache for reusing computed prompt prefixes across requests."""
 
+from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Any, Callable, Protocol
 
 import mlx.core as mx
-from mlx_lm import stream_generate
-from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
-from mlx_lm.tokenizer_utils import TokenizerWrapper
+import numpy as np
+from mlx_lm.models.cache import trim_prompt_cache
 
 from exo.worker.engines.mlx import Model
-from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
 from exo.worker.engines.mlx.utils_mlx import make_kv_cache
+from exo.worker.runner.bootstrap import logger
+
+# Type alias for KV cache - the actual type is _BaseCache but it's private
+KVCacheType = Any
+
+
+class TokenizerProtocol(Protocol):
+    """Protocol for tokenizers used with KVPrefixCache."""
+
+    bos_token: str | None
+
+    def encode(self, text: str, **kwargs: bool) -> list[int]: ...
 
 
 class KVPrefixCache:
-    def __init__(self):
-        # Only one prefix cache per runner.
-        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
-        self.caches: list[list[_BaseCache]] = []
-
-    def add_kv_cache(
-        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
-    ):
-        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
-        self.prompts.append(tokenized_prompt)
-        self.caches.append(deepcopy(cache))
+    """Cache for common prompt prefixes to avoid re-processing.
+
+    Uses LRU eviction when capacity is reached. Stores tokenized prompts
+    and their corresponding KV caches for reuse.
+    """
+
+    def __init__(self, max_size: int = 10):
+        """Initialize prefix cache.
+
+        Args:
+            max_size: Maximum number of cached entries before LRU eviction.
+        """
+        self.max_size = max_size
+        # OrderedDict maintains insertion order for LRU - most recent at end
+        # Key: token bytes, Value: (tokens as mx.array, KV cache)
+        self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
+            OrderedDict()
+        )
+
+    def _token_key(self, tokens: mx.array) -> bytes:
+        """Create hashable key from token array."""
+        return np.array(tokens.tolist(), dtype=np.int32).tobytes()
+
+    def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
+        """Tokenize prompt string to mx.array."""
+        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+            tokenizer.bos_token
+        )
+        tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        return mx.array(tokenized)
+
+    def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
+        """Find cached entry with longest matching prefix.
+
+        Returns:
+            Tuple of (cache_key, prefix_length). cache_key is None if no match found.
+        """
+        best_key: bytes | None = None
+        best_length = 0
+        target_len = tokens.shape[0]
+
+        for key, (cached_tokens, _cache) in self._cache.items():
+            prefix_len = get_prefix_length(tokens, cached_tokens)
+
+            # Exact match - return immediately
+            if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
+                return key, prefix_len
+
+            # Better prefix match
+            if prefix_len > best_length:
+                best_key = key
+                best_length = prefix_len
+
+        return best_key, best_length
 
     def get_kv_cache(
         self,
         model: Model,
-        tokenizer: TokenizerWrapper,
+        tokenizer: TokenizerProtocol,
         sampler: Callable[[mx.array], mx.array],
         prompt: str,
-    ) -> list[_BaseCache]:
-        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
-        max_length = len(tokenized_prompt)
-
-        best_snapshot_index, best_snapshot_length = None, 0
-
-        for i, cached_prompt in enumerate(self.prompts):
-            length = _get_prefix_length(tokenized_prompt, cached_prompt)
-
-            if length == max_length:
-                return self.caches[i]
-
-            if length > best_snapshot_length:
-                best_snapshot_index, best_snapshot_length = i, length
-
-        if best_snapshot_index is not None:
-            prompt_cache = deepcopy(self.caches[best_snapshot_index])
-            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
-            tokenized_prompt = tokenized_prompt[best_snapshot_index:]
-        else:
-            prompt_cache = make_kv_cache(
-                model,
-                # max_kv_size=MAX_KV_SIZE,
-                # keep=KEEP_KV_SIZE
+    ) -> tuple[list[KVCacheType], int]:
+        """Get KV cache for prompt, reusing prefix if available.
+
+        Args:
+            model: The model to create cache for.
+            tokenizer: Tokenizer for encoding prompt.
+            sampler: Sampler function for prefill.
+            prompt: The prompt string to process.
+
+        Returns:
+            Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
+            tokens were reused from cache (0 if no cache hit).
+        """
+        tokens = self._encode_prompt(tokenizer, prompt)
+        target_len = int(tokens.shape[0])
+
+        # Find best prefix match
+        best_key, prefix_len = self._find_best_prefix(tokens)
+
+        if best_key is not None and prefix_len > 0:
+            cached_tokens, cached_kv = self._cache[best_key]
+            cached_len = int(cached_tokens.shape[0])
+
+            # Move to end (most recently used)
+            self._cache.move_to_end(best_key)
+
+            if prefix_len == target_len and prefix_len == cached_len:
+                # Exact match - return deepcopy directly
+                logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
+                return deepcopy(cached_kv), prefix_len
+
+            # Partial match - need to trim and/or extend
+            prompt_cache = deepcopy(cached_kv)
+
+            if cached_len > prefix_len:
+                # Cached prompt is longer - trim to prefix length
+                num_to_trim = cached_len - prefix_len
+                trim_prompt_cache(prompt_cache, num_to_trim)
+                logger.debug(
+                    f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
+                )
+
+            # Note: We don't prefill remaining tokens here - stream_generate will do it
+            # when processing the full prompt with this partial cache
+            return prompt_cache, prefix_len
+
+        # No cache hit - return fresh cache (stream_generate will prefill)
+        logger.debug(
+            f"Prefix cache: miss, will prefill {target_len} tokens during generation"
+        )
+        prompt_cache = make_kv_cache(model=model)
+
+        return prompt_cache, 0
+
+    def put(
+        self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
+    ) -> None:
+        """Store KV cache for prompt after generation completes.
+
+        Args:
+            tokenizer: Tokenizer for encoding prompt.
+            prompt: The prompt string that was processed.
+            cache: The KV cache to store.
+        """
+        tokens = self._encode_prompt(tokenizer, prompt)
+        key = self._token_key(tokens)
+
+        # If already in cache, just move to end
+        if key in self._cache:
+            self._cache.move_to_end(key)
+            return
+
+        # Evict LRU entry if at capacity
+        if len(self._cache) >= self.max_size:
+            evicted_key, _ = self._cache.popitem(last=False)
+            logger.debug(
+                f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
             )
 
-        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
-
-        return prompt_cache
-
-    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
-        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
-            tokenizer.bos_token
-        )
-        tokenized_prompt = tokenizer.encode(
-            prompt, add_special_tokens=add_special_tokens
-        )
-        return mx.array(tokenized_prompt)
-
-
-def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
-    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
+        # Store deepcopy
+        self._cache[key] = (tokens, deepcopy(cache))
+        logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")
+
+    def clear(self) -> None:
+        """Clear all cached entries."""
+        self._cache.clear()
+
+    def __len__(self) -> int:
+        """Return number of cached entries."""
+        return len(self._cache)
+
+
+def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
+    """Calculate length of matching prefix between two token arrays."""
+    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
     if n == 0:
         return 0
 
-    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
+    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
     prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
     return int(mx.sum(prefix_mask).item())
-
-
-def prefill(
-    model: Model,
-    tokenizer: TokenizerWrapper,
-    sampler: Callable[[mx.array], mx.array],
-    prompt: mx.array,
-    cache: list[_BaseCache],
-) -> None:
-    for _ in stream_generate(
-        model=model,
-        tokenizer=tokenizer,
-        prompt=prompt,
-        max_tokens=0,
-        sampler=sampler,
-        prompt_cache=cache,
-        prefill_step_size=2048,
-        kv_group_size=KV_GROUP_SIZE,
-        kv_bits=KV_BITS,
-    ):
-        pass
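The `get_prefix_length` helper in this diff avoids a Python-level loop over tokens: it builds a 0/1 equality mask and takes a cumulative product, which stays 1 only while every earlier position matched, so summing the mask yields the common-prefix length. A minimal standalone sketch of the same idea, written with numpy instead of mlx purely so it runs on any machine:

```python
import numpy as np


def prefix_length(a: np.ndarray, b: np.ndarray) -> int:
    """Length of the common prefix of two token arrays (numpy rendition)."""
    n = min(a.shape[0], b.shape[0])
    if n == 0:
        return 0
    equal = (a[:n] == b[:n]).astype(np.int32)  # 1 where tokens agree
    prefix_mask = np.cumprod(equal)  # 1 until the first mismatch, 0 after
    return int(prefix_mask.sum())  # number of leading matches


assert prefix_length(np.array([1, 2, 3, 4]), np.array([1, 2, 9, 4])) == 2
assert prefix_length(np.array([1, 2]), np.array([1, 2, 3])) == 2
```

The mlx version in the diff has the same structure, expressed with `mx.equal`, `mx.cumprod`, and `mx.sum`.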
src/exo/worker/engines/mlx/generator/generate.py

@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
 from mlx_lm.sample_utils import make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 
-# from exo.engines.mlx.cache import KVPrefixCache
 from exo.shared.types.api import (
     BenchChatCompletionTaskParams,
     ChatCompletionMessage,
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
     GenerationResponse,
 )
 from exo.worker.engines.mlx import Model
+from exo.worker.engines.mlx.cache import KVPrefixCache
 from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
 from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
@@ -119,6 +119,7 @@ def mlx_generate(
     model: Model,
     tokenizer: TokenizerWrapper,
     task: ChatCompletionTaskParams,
+    prefix_cache: KVPrefixCache | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -135,8 +136,6 @@ def mlx_generate(
         chat_task_data=task,
     )
 
-    caches = make_kv_cache(model=model)
-
     logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
     if is_bench:
         # Only sample length eos tokens
@@ -148,6 +147,20 @@ def mlx_generate(
         top_p=task.top_p if task.top_p is not None else 1.0,
     )
 
+    # Get KV cache - either from prefix cache or fresh
+    tokens_reused = 0
+    if prefix_cache is not None:
+        caches, tokens_reused = prefix_cache.get_kv_cache(
+            model=model,
+            tokenizer=tokenizer,
+            sampler=sampler,
+            prompt=prompt,
+        )
+        if tokens_reused > 0:
+            logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
+    else:
+        caches = make_kv_cache(model=model)
+
     max_tokens = task.max_tokens or MAX_TOKENS
     for out in stream_generate(
         model=model,
@@ -189,6 +202,9 @@ def mlx_generate(
         )
 
         if out.finish_reason is not None:
+            # Store in prefix cache for future reuse
+            if prefix_cache is not None:
+                prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
            break
 
     # TODO: Do we want an mx_barrier?
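The point of threading `tokens_reused` back to the caller is that `stream_generate` then only has to prefill the un-cached suffix of the prompt. A back-of-envelope sketch with hypothetical numbers (illustrative only, not measurements from this change):

```python
# Hypothetical illustration: prefill work saved by a prefix-cache hit.
prompt_len = 4096  # tokens in the incoming prompt
tokens_reused = 4000  # longest cached prefix returned by get_kv_cache

to_prefill = prompt_len - tokens_reused
print(f"prefill {to_prefill} tokens instead of {prompt_len}")
# prefill 96 tokens instead of 4096 -- roughly 2% of the original prefill work
```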
runner entrypoint (file path not shown in the mirror)

@@ -39,6 +39,7 @@ from exo.shared.types.worker.runners import (
     RunnerWarmingUp,
 )
 from exo.utils.channels import MpReceiver, MpSender
+from exo.worker.engines.mlx.cache import KVPrefixCache
 from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
 from exo.worker.engines.mlx.utils_mlx import (
     initialize_mlx,
@@ -69,6 +70,7 @@ def main(
     model = None
     tokenizer = None
     group = None
+    prefix_cache: KVPrefixCache | None = None
 
     current_status: RunnerStatus = RunnerIdle()
     logger.info("runner created")
@@ -110,6 +112,8 @@ def main(
             )
 
             model, tokenizer = load_mlx_items(bound_instance, group)
+            prefix_cache = KVPrefixCache(max_size=10)
+            logger.info("prefix cache initialized")
 
             current_status = RunnerLoaded()
             logger.info("runner loaded")
@@ -157,6 +161,7 @@ def main(
                 model=model,
                 tokenizer=tokenizer,
                 task=task_params,
+                prefix_cache=prefix_cache,
             ):
                 match response:
                     case GenerationResponse():
src/exo/worker/tests/unittests/test_mlx/test_prefix_cache.py (new file, 141 lines)

@@ -0,0 +1,141 @@
+"""Tests for KVPrefixCache."""
+
+import mlx.core as mx
+import pytest
+
+from exo.worker.engines.mlx.cache import (
+    KVCacheType,
+    KVPrefixCache,
+    TokenizerProtocol,
+    get_prefix_length,
+)
+
+
+class MockTokenizer(TokenizerProtocol):
+    """Mock tokenizer that converts string to list of char codes."""
+
+    bos_token: str | None = None
+
+    def encode(self, text: str, **kwargs: bool) -> list[int]:
+        """Encode text to list of character codes."""
+        del kwargs  # unused
+        return [ord(c) for c in text]
+
+
+class TestGetPrefixLength:
+    """Tests for the core prefix matching algorithm."""
+
+    def test_identical_arrays(self) -> None:
+        a = mx.array([1, 2, 3, 4, 5])
+        b = mx.array([1, 2, 3, 4, 5])
+        assert get_prefix_length(a, b) == 5
+
+    def test_partial_match(self) -> None:
+        a = mx.array([1, 2, 3, 4, 5])
+        b = mx.array([1, 2, 3, 9, 9])
+        assert get_prefix_length(a, b) == 3
+
+    def test_no_match(self) -> None:
+        a = mx.array([1, 2, 3])
+        b = mx.array([9, 9, 9])
+        assert get_prefix_length(a, b) == 0
+
+    def test_different_lengths(self) -> None:
+        short = mx.array([1, 2, 3])
+        long = mx.array([1, 2, 3, 4, 5])
+        # Should return length of shorter when they match
+        assert get_prefix_length(short, long) == 3
+        assert get_prefix_length(long, short) == 3
+
+    def test_empty_array(self) -> None:
+        empty: mx.array = mx.array([])
+        tokens = mx.array([1, 2, 3])
+        assert get_prefix_length(empty, tokens) == 0
+        assert get_prefix_length(tokens, empty) == 0
+
+
+class TestKVPrefixCache:
+    """Tests for the KV prefix cache."""
+
+    @pytest.fixture
+    def tokenizer(self) -> MockTokenizer:
+        """Mock tokenizer that converts string to list of char codes."""
+        return MockTokenizer()
+
+    @pytest.fixture
+    def fake_kv(self) -> list[KVCacheType]:
+        """Fake KV cache for testing."""
+        return [object()]
+
+    def test_put_stores_entry(
+        self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
+    ) -> None:
+        cache = KVPrefixCache(max_size=10)
+
+        cache.put(tokenizer, "hello", fake_kv)
+
+        assert len(cache) == 1
+
+    def test_put_same_prompt_twice_does_not_duplicate(
+        self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
+    ) -> None:
+        cache = KVPrefixCache(max_size=10)
+
+        cache.put(tokenizer, "hello", fake_kv)
+        cache.put(tokenizer, "hello", fake_kv)
+
+        assert len(cache) == 1
+
+    def test_lru_eviction(
+        self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
+    ) -> None:
+        cache = KVPrefixCache(max_size=2)
+
+        # Fill cache
+        cache.put(tokenizer, "first", fake_kv)
+        cache.put(tokenizer, "second", fake_kv)
+        assert len(cache) == 2
+
+        # Add third - should evict "first" (oldest)
+        cache.put(tokenizer, "third", fake_kv)
+        assert len(cache) == 2
+
+        # Add "first" again - if it was evicted, cache size stays 2
+        # If it wasn't evicted, this would be a no-op
+        cache.put(tokenizer, "first", fake_kv)
+        # Now add fourth - if "first" was re-added, size is still 2
+        cache.put(tokenizer, "fourth", fake_kv)
+        assert len(cache) == 2
+
+    def test_lru_access_refreshes_entry(
+        self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
+    ) -> None:
+        cache = KVPrefixCache(max_size=2)
+
+        # Add two entries
+        cache.put(tokenizer, "first", fake_kv)
+        cache.put(tokenizer, "second", fake_kv)
+
+        # Access "first" again (moves to end of LRU)
+        cache.put(tokenizer, "first", fake_kv)
+
+        # Add third - should evict "second" now (oldest)
+        cache.put(tokenizer, "third", fake_kv)
+
+        # Add "second" again - this will add it as new entry
+        cache.put(tokenizer, "second", fake_kv)
+        # Now "first" is oldest, adding fourth should evict it
+        cache.put(tokenizer, "fourth", fake_kv)
+
+        # Cache should have "second" and "fourth", not "first"
+        assert len(cache) == 2
+
+    def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
+        cache = KVPrefixCache()
+        cache.put(tokenizer, "hello", fake_kv)
+        cache.put(tokenizer, "world", fake_kv)
+        assert len(cache) == 2
+
+        cache.clear()
+
+        assert len(cache) == 0
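The eviction and refresh behavior these tests assert falls out of the `OrderedDict` pattern used by `put`: a membership check plus `move_to_end` on a hit, `popitem(last=False)` on overflow. A self-contained sketch of that pattern with toy string keys (the real cache maps token bytes to deep-copied KV caches):

```python
from collections import OrderedDict

lru: OrderedDict[str, int] = OrderedDict()
MAX_SIZE = 2


def put(key: str, value: int) -> None:
    if key in lru:
        lru.move_to_end(key)  # refresh recency, keep the stored value
        return
    if len(lru) >= MAX_SIZE:
        lru.popitem(last=False)  # evict the least recently used entry
    lru[key] = value


put("first", 1)
put("second", 2)
put("first", 1)   # hit: "first" becomes most recently used
put("third", 3)   # evicts "second", now the oldest entry
assert list(lru) == ["first", "third"]
```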