diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index 072f8a0b0..aaa0d6f34 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -14,11 +14,13 @@ import backend_pb2_grpc import grpc from mlx_lm import load, generate, stream_generate from mlx_lm.sample_utils import make_sampler -from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache import mlx.core as mx import base64 import io +from mlx_cache import ThreadSafeLRUPromptCache + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 @@ -118,10 +120,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) else: self.model, self.tokenizer = load(request.Model) - - # Initialize prompt cache for efficient generation - max_kv_size = self.options.get("max_kv_size", None) - self.prompt_cache = make_prompt_cache(self.model, max_kv_size) + + # Initialize thread-safe LRU prompt cache for efficient generation + max_cache_entries = self.options.get("max_cache_entries", 10) + self.max_kv_size = self.options.get("max_kv_size", None) + self.model_key = request.Model + self.lru_cache = ThreadSafeLRUPromptCache( + max_size=max_cache_entries, + can_trim_fn=can_trim_prompt_cache, + trim_fn=trim_prompt_cache, + ) except Exception as err: print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr) @@ -134,6 +142,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): """ Generates text based on the given prompt and sampling parameters using MLX. + Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. + Args: request: The predict request. context: The gRPC context. @@ -141,31 +151,48 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Returns: backend_pb2.Reply: The predict result. 
""" + prompt_cache = None + cache_key = None + try: - # Prepare the prompt - prompt = self._prepare_prompt(request) - + # Prepare the prompt and tokenize for cache key + prompt_text = self._prepare_prompt(request) + cache_key = self._get_tokens_from_prompt(prompt_text) + + # Fetch nearest cache (exact, shorter prefix, or create new) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + # Build generation parameters using request attributes and options max_tokens, sampler_params = self._build_generation_params(request) - - print(f"Generating text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr) - + + print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + # Create sampler with parameters sampler = make_sampler(**sampler_params) - - # Generate text using MLX with proper parameters - response = generate( + + # Use stream_generate to track generated tokens for cache key + generated_text = [] + for response in stream_generate( self.model, self.tokenizer, - prompt=prompt, + prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, - prompt_cache=self.prompt_cache, - verbose=False - ) - - return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) - + prompt_cache=prompt_cache, + ): + generated_text.append(response.text) + cache_key.append(response.token) + + # Insert completed cache + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + + return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8')) + except Exception as e: print(f"Error in MLX Predict: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) @@ -194,6 +221,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): """ Generates text based on the given prompt and sampling parameters, and streams the results using MLX. + Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. + Args: request: The predict stream request. context: The gRPC context. @@ -201,35 +230,56 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Yields: backend_pb2.Reply: Streaming predict results. 
""" + prompt_cache = None + cache_key = None + try: - # Prepare the prompt - prompt = self._prepare_prompt(request) - + # Prepare the prompt and tokenize for cache key + prompt_text = self._prepare_prompt(request) + cache_key = self._get_tokens_from_prompt(prompt_text) + + # Fetch nearest cache (exact, shorter prefix, or create new) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + # Build generation parameters using request attributes and options max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) - - print(f"Streaming text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr) - + + print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + # Create sampler with parameters sampler = make_sampler(**sampler_params) - + # Stream text generation using MLX with proper parameters for response in stream_generate( self.model, self.tokenizer, - prompt=prompt, + prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, - prompt_cache=self.prompt_cache, + prompt_cache=prompt_cache, ): + cache_key.append(response.token) yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) - + except Exception as e: print(f"Error in MLX PredictStream: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Streaming generation failed: {str(e)}") yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + finally: + # Always insert cache, even on interruption + if prompt_cache is not None and cache_key is not None: + try: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + except Exception as e: + print(f"Error inserting cache: {e}", file=sys.stderr) + def _prepare_prompt(self, request): """ Prepare the prompt for MLX generation, handling chat templates if needed. @@ -246,16 +296,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): messages = [] for msg in request.Messages: messages.append({"role": msg.role, "content": msg.content}) - + prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, + messages, + tokenize=False, add_generation_prompt=True ) return prompt else: return request.Prompt + def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: + """ + Tokenize prompt text for cache key generation. + + Args: + prompt_text: The prompt string to tokenize. + + Returns: + List[int]: List of token IDs. 
+ """ + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + return tokens.tolist() + return list(tokens) + @@ -284,11 +349,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): top_p = getattr(request, 'TopP', 0.0) if top_p == 0.0: top_p = 1.0 # Default top_p - + + min_p = getattr(request, 'MinP', 0.0) + # min_p default of 0.0 means disabled (no filtering) + + top_k = getattr(request, 'TopK', 0) + # top_k default of 0 means disabled (no filtering) + # Initialize sampler parameters sampler_params = { 'temp': temp, 'top_p': top_p, + 'min_p': min_p, + 'top_k': top_k, 'xtc_threshold': 0.0, 'xtc_probability': 0.0, } @@ -308,7 +381,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): sampler_option_mapping = { 'temp': 'temp', 'temperature': 'temp', # alias - 'top_p': 'top_p', + 'top_p': 'top_p', + 'min_p': 'min_p', + 'top_k': 'top_k', 'xtc_threshold': 'xtc_threshold', 'xtc_probability': 'xtc_probability', } diff --git a/backend/python/mlx/mlx_cache.py b/backend/python/mlx/mlx_cache.py new file mode 100644 index 000000000..6ec2bb9ba --- /dev/null +++ b/backend/python/mlx/mlx_cache.py @@ -0,0 +1,266 @@ +""" +Thread-safe LRU prompt cache for MLX-based backends. + +Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.) +with thread-safety additions for LocalAI's gRPC backend. + +Usage: + from mlx_cache import ThreadSafeLRUPromptCache + + # In LoadModel: + self.lru_cache = ThreadSafeLRUPromptCache(max_size=10) + + # In Predict/PredictStream: + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens) + # ... generate ... + self.lru_cache.insert_cache(model_key, tokens, prompt_cache) +""" +import copy +import threading +from collections import deque +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + + +@dataclass +class CacheEntry: + """A cache entry with reference counting.""" + prompt_cache: List[Any] + count: int + + +@dataclass +class SearchResult: + """Result of searching the cache trie.""" + model: Any + exact: Optional[List[int]] + shorter: Optional[List[int]] + longer: Optional[List[int]] + common_prefix: int + + +class ThreadSafeLRUPromptCache: + """ + Thread-safe LRU cache with prefix matching for prompt KV caches. + + This cache stores KV caches keyed by token sequences and supports: + - Exact match: Return the cache for the exact token sequence + - Shorter prefix match: Return a cache for a prefix of the tokens + - Longer prefix match: If a longer sequence is cached and can be trimmed + - LRU eviction: When max_size is exceeded, evict least recently used + + Thread safety is provided via a threading.Lock that protects all + cache operations. + + Args: + max_size: Maximum number of cache entries (default: 10) + can_trim_fn: Optional function to check if a cache can be trimmed + trim_fn: Optional function to trim a cache + """ + + def __init__( + self, + max_size: int = 10, + can_trim_fn: Optional[Any] = None, + trim_fn: Optional[Any] = None, + ): + self.max_size = max_size + self._cache = {} + self._lru = deque() + self._lock = threading.Lock() + + # Optional trim functions (for longer prefix reuse) + self._can_trim_fn = can_trim_fn + self._trim_fn = trim_fn + + def _search(self, model, tokens: List[int]) -> SearchResult: + """ + Search the cache for a prompt cache. Return exact or close match. + + The cache is organized as a trie where each node is keyed by a token. + This allows efficient prefix matching. 
+ """ + if model not in self._cache: + return SearchResult(model, None, None, None, 0) + + current = self._cache[model] + last_cache_index = -1 + index = 0 + + # Traverse the trie following the token sequence + while index < len(tokens) and tokens[index] in current: + current = current[tokens[index]] + if "cache" in current: + last_cache_index = index + index += 1 + + # Exact match - no need to search for longer or shorter caches + if last_cache_index == len(tokens) - 1: + return SearchResult(model, tuple(tokens), None, None, 0) + + # Find the shorter cache (a prefix that has a cache) + # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior. + # Single-token prefixes are not matched, which allows longer cached + # sequences to be preferred for trimming. This is acceptable because + # real prompts with chat templates are always many tokens. + shorter = None + if last_cache_index > 0: + shorter = tuple(tokens[: last_cache_index + 1]) + + # Check for caches that are longer than our token sequence + longer = None + common_prefix = index + if index > 0 and last_cache_index <= 0: + best = None + stack = [(current, [])] + while stack: + current, extra = stack.pop() + if "cache" in current: + if best is None or len(extra) < len(best): + best = extra + else: + for tok in current: + stack.append((current[tok], extra + [tok])) + if best is not None: + longer = tuple(tokens[:index] + best) + + return SearchResult(model, None, shorter, longer, common_prefix) + + def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """Get a cache entry by traversing the trie.""" + current = self._cache[model] + for tok in tokens: + current = current[tok] + return current["cache"] + + def _delete(self, model, tokens: Tuple[int, ...]) -> None: + """Delete a cache entry and clean up empty trie nodes.""" + path = [self._cache[model]] + for tok in tokens: + path.append(path[-1][tok]) + del path[-1]["cache"] + + # Clean up empty nodes bottom-up + for i in reversed(range(len(tokens))): + d_prev, d, t = path[i], path[i + 1], tokens[i] + if len(d) > 0: + break + del d_prev[t] + + def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """ + Extract a cache entry for exclusive use. + + If the entry has count > 1, deep copy and decrement. + If count == 1, remove from cache entirely. + """ + cache_entry = self._get(model, tokens) + if cache_entry.count == 1: + self._delete(model, tokens) + self._lru.remove((model, tokens)) + return cache_entry + + cache_entry.count -= 1 + return CacheEntry( + copy.deepcopy(cache_entry.prompt_cache), + 1, + ) + + def fetch_nearest_cache( + self, model, tokens: List[int] + ) -> Tuple[Optional[List[Any]], List[int]]: + """ + Fetch the nearest cache for the given token sequence. + + Thread-safe. 
Returns (cache, remaining_tokens) where: + - cache: The KV cache to use (or None if no cache found) + - remaining_tokens: Tokens that still need to be processed + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence for the prompt + + Returns: + Tuple of (prompt_cache, remaining_tokens) + """ + with self._lock: + tokens_tuple = tuple(tokens) + result = self._search(model, tokens) + + # Exact match - extract and return + if result.exact is not None: + cache_entry = self._extract(result.model, result.exact) + return cache_entry.prompt_cache, [] + + # Shorter prefix match - extract and return remaining + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, list(tokens[prefix_len:]) + + # Longer prefix match - try to trim if possible + if result.longer is not None and self._can_trim_fn is not None: + cache_entry = self._get(result.model, result.longer) + if self._can_trim_fn(cache_entry.prompt_cache): + # Deep copy and trim + trimmed_cache = copy.deepcopy(cache_entry.prompt_cache) + prefix = min(len(tokens) - 1, result.common_prefix) + num_to_trim = len(result.longer) - prefix + if self._trim_fn is not None: + self._trim_fn(trimmed_cache, num_to_trim) + return trimmed_cache, list(tokens[prefix:]) + + # No match found + return None, list(tokens) + + def insert_cache( + self, model, tokens: List[int], prompt_cache: List[Any] + ) -> None: + """ + Insert a cache entry after generation completes. + + Thread-safe. Handles LRU eviction if max_size is exceeded. + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence (prompt + generated) + prompt_cache: The KV cache to store + """ + with self._lock: + tokens_tuple = tuple(tokens) + + if model not in self._cache: + self._cache[model] = {} + current = self._cache[model] + + # Build trie path + for tok in tokens_tuple: + if tok not in current: + current[tok] = {} + current = current[tok] + + # Update or create entry + if "cache" in current: + current["cache"].count += 1 + self._lru.remove((model, tokens_tuple)) + else: + current["cache"] = CacheEntry(prompt_cache, 1) + + # Update LRU order + self._lru.append((model, tokens_tuple)) + + # Evict if over capacity + if len(self._lru) > self.max_size: + evict_model, evict_tokens = self._lru.popleft() + self._delete(evict_model, evict_tokens) + + def clear(self) -> None: + """Clear all cache entries. Thread-safe.""" + with self._lock: + self._cache.clear() + self._lru.clear() + + def __len__(self) -> int: + """Return the number of cache entries. 
Thread-safe.""" + with self._lock: + return len(self._lru) diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index 827aa71a3..53d7bc7ec 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -1,17 +1,10 @@ import unittest import subprocess import time -import backend_pb2 -import backend_pb2_grpc import grpc - -import unittest -import subprocess -import time -import grpc -import backend_pb2_grpc import backend_pb2 +import backend_pb2_grpc class TestBackendServicer(unittest.TestCase): """ @@ -47,9 +40,9 @@ class TestBackendServicer(unittest.TestCase): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) - self.assertEqual(response.message, "Model loaded successfully") + self.assertEqual(response.message, "MLX model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") @@ -64,7 +57,7 @@ class TestBackendServicer(unittest.TestCase): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) @@ -84,7 +77,7 @@ class TestBackendServicer(unittest.TestCase): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions( @@ -95,26 +88,13 @@ class TestBackendServicer(unittest.TestCase): TopK=40, PresencePenalty=0.1, FrequencyPenalty=0.2, - RepetitionPenalty=1.1, MinP=0.05, Seed=42, StopPrompts=["\n"], - StopTokenIds=[50256], - BadWords=["badword"], - IncludeStopStrInOutput=True, IgnoreEOS=True, - MinTokens=5, - Logprobs=5, - PromptLogprobs=5, - SkipSpecialTokens=True, - SpacesBetweenSpecialTokens=True, - TruncatePromptTokens=10, - GuidedDecoding=True, - N=2, ) resp = stub.Predict(req) self.assertIsNotNone(resp.message) - self.assertIsNotNone(resp.logprobs) except Exception as err: print(err) self.fail("sampling params service failed") @@ -143,4 +123,112 @@ class TestBackendServicer(unittest.TestCase): print(err) self.fail("Embedding service failed") finally: - self.tearDown() \ No newline at end of file + self.tearDown() + + def test_concurrent_requests(self): + """ + This method tests that concurrent requests don't corrupt each other's cache state. + This is a regression test for the race condition in the original implementation. 
+ """ + import concurrent.futures + + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + + def make_request(prompt): + req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20) + return stub.Predict(req) + + # Run 5 concurrent requests with different prompts + prompts = [ + "The capital of France is", + "The capital of Germany is", + "The capital of Italy is", + "The capital of Spain is", + "The capital of Portugal is", + ] + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(make_request, p) for p in prompts] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + # All results should be non-empty + messages = [r.message for r in results] + self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses") + print(f"Concurrent test passed: {len(messages)} responses received") + + except Exception as err: + print(err) + self.fail("Concurrent requests test failed") + finally: + self.tearDown() + + def test_cache_reuse(self): + """ + This method tests that repeated prompts reuse cached KV states. + The second request should benefit from the cached prompt processing. + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + + prompt = "The quick brown fox jumps over the lazy dog. " + + # First request - populates cache + req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10) + resp1 = stub.Predict(req1) + self.assertIsNotNone(resp1.message) + + # Second request with same prompt - should reuse cache + req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10) + resp2 = stub.Predict(req2) + self.assertIsNotNone(resp2.message) + + print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes") + + except Exception as err: + print(err) + self.fail("Cache reuse test failed") + finally: + self.tearDown() + + def test_prefix_cache_reuse(self): + """ + This method tests that prompts sharing a common prefix benefit from cached KV states. 
+ """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + + # First request with base prompt + prompt_base = "Once upon a time in a land far away, " + req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10) + resp1 = stub.Predict(req1) + self.assertIsNotNone(resp1.message) + + # Second request with extended prompt (same prefix) + prompt_extended = prompt_base + "there lived a brave knight who " + req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10) + resp2 = stub.Predict(req2) + self.assertIsNotNone(resp2.message) + + print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes") + + except Exception as err: + print(err) + self.fail("Prefix cache reuse test failed") + finally: + self.tearDown() + + +# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py \ No newline at end of file diff --git a/backend/python/mlx/test_mlx_cache.py b/backend/python/mlx/test_mlx_cache.py new file mode 100644 index 000000000..c888782e9 --- /dev/null +++ b/backend/python/mlx/test_mlx_cache.py @@ -0,0 +1,480 @@ +""" +Comprehensive unit tests for ThreadSafeLRUPromptCache. + +Tests all cache operation modes: +- Exact match +- Shorter prefix match +- Longer prefix match (with trimming) +- No match +- LRU eviction +- Reference counting +- Multi-model namespacing +- Thread safety with data integrity verification +""" +import unittest +import concurrent.futures +import threading +import copy +from mlx_cache import ThreadSafeLRUPromptCache + + +class TestCacheExactMatch(unittest.TestCase): + """Tests for exact match cache behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_exact_match_returns_cache_and_empty_remaining(self): + """Exact match should return the cache with no remaining tokens.""" + tokens = [1, 2, 3, 4, 5] + mock_cache = ["kv_cache_data"] + + self.cache.insert_cache("model1", tokens, mock_cache) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + + self.assertEqual(result_cache, mock_cache) + self.assertEqual(remaining, []) + + def test_exact_match_extracts_and_removes_from_cache(self): + """Fetching exact match with count=1 should remove entry from cache.""" + tokens = [1, 2, 3] + self.cache.insert_cache("model1", tokens, ["cache"]) + + self.assertEqual(len(self.cache), 1) + + # First fetch extracts the entry + self.cache.fetch_nearest_cache("model1", tokens) + + # Cache should now be empty + self.assertEqual(len(self.cache), 0) + + # Second fetch should return None (no match) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + self.assertIsNone(result_cache) + self.assertEqual(remaining, tokens) + + +class TestCacheShorterPrefix(unittest.TestCase): + """Tests for shorter prefix match behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_shorter_prefix_returns_cache_with_remaining_tokens(self): + """When cached prefix is shorter, return cache and remaining suffix.""" + short_tokens = [1, 2, 3] + long_tokens = [1, 2, 3, 4, 5, 6] + mock_cache = ["prefix_cache"] + + self.cache.insert_cache("model1", short_tokens, mock_cache) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) + + self.assertEqual(result_cache, mock_cache) + 
self.assertEqual(remaining, [4, 5, 6]) + + def test_shorter_prefix_correct_remaining_calculation(self): + """Verify remaining tokens are calculated correctly for various prefix lengths.""" + # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched + # to allow longer cached sequences to be preferred for trimming. + # This matches upstream mlx_lm/server.py behavior. + test_cases = [ + # (cached_tokens, requested_tokens, expected_remaining) + ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]), + ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]), + ] + + for cached, requested, expected_remaining in test_cases: + with self.subTest(cached=cached, requested=requested): + cache = ThreadSafeLRUPromptCache(max_size=10) + cache.insert_cache("model", cached, ["cache"]) + result_cache, remaining = cache.fetch_nearest_cache("model", requested) + + self.assertIsNotNone(result_cache) + self.assertEqual(remaining, expected_remaining) + + def test_single_token_prefix_not_matched(self): + """Single-token prefixes are not matched (by design, matches upstream). + + This allows longer cached sequences to be preferred for trimming, + which provides better KV cache reuse. Single-token caches are rare + in practice since real prompts with chat templates are many tokens. + """ + cache = ThreadSafeLRUPromptCache(max_size=10) + cache.insert_cache("model", [1], ["cache"]) + + result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3]) + + # Single-token prefix is NOT matched + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheLongerPrefix(unittest.TestCase): + """Tests for longer prefix match behavior (trimming).""" + + def setUp(self): + # Track trim calls for verification + self.trim_calls = [] + + def mock_can_trim(cache): + return True + + def mock_trim(cache, num_to_trim): + self.trim_calls.append(num_to_trim) + # Simulate trimming by modifying the cache + cache.append(f"trimmed_{num_to_trim}") + + self.cache = ThreadSafeLRUPromptCache( + max_size=10, + can_trim_fn=mock_can_trim, + trim_fn=mock_trim, + ) + + def test_longer_prefix_triggers_trim(self): + """When cached sequence is longer, should trim to match requested prefix.""" + long_tokens = [1, 2, 3, 4, 5] + short_tokens = [1, 2, 3] + + self.cache.insert_cache("model1", long_tokens, ["original_cache"]) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens) + + # Should have called trim + self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called") + # Result should be a trimmed copy, not the original + self.assertIn("trimmed_", str(result_cache)) + + def test_longer_prefix_without_trim_fn_returns_no_match(self): + """Without trim functions, longer prefix should not match.""" + cache_no_trim = ThreadSafeLRUPromptCache(max_size=10) + + long_tokens = [1, 2, 3, 4, 5] + short_tokens = [1, 2, 3] + + cache_no_trim.insert_cache("model1", long_tokens, ["cache"]) + result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens) + + # Without trim_fn, should return no match + self.assertIsNone(result_cache) + self.assertEqual(remaining, short_tokens) + + def test_longer_prefix_can_trim_false_returns_no_match(self): + """When can_trim_fn returns False, should not attempt trim.""" + cache = ThreadSafeLRUPromptCache( + max_size=10, + can_trim_fn=lambda c: False, + trim_fn=lambda c, n: None, + ) + + cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"]) + result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3]) + + 
self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheNoMatch(unittest.TestCase): + """Tests for no match behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_empty_cache_returns_none(self): + """Empty cache should return None and all tokens as remaining.""" + tokens = [1, 2, 3] + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, tokens) + + def test_different_prefix_returns_none(self): + """Tokens with different prefix should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + # Completely different tokens + result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [4, 5, 6]) + + def test_partial_prefix_mismatch_returns_none(self): + """Tokens that diverge mid-sequence should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + # Same start but diverges + result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 99]) + + def test_wrong_model_returns_none(self): + """Different model key should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheLRUEviction(unittest.TestCase): + """Tests for LRU eviction behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=3) + + def test_evicts_oldest_when_full(self): + """Should evict least recently used entry when capacity exceeded.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.insert_cache("model", [2], ["cache2"]) + self.cache.insert_cache("model", [3], ["cache3"]) + + self.assertEqual(len(self.cache), 3) + + # Insert 4th entry - should evict [1] + self.cache.insert_cache("model", [4], ["cache4"]) + + self.assertEqual(len(self.cache), 3) + + # [1] should be evicted + result, _ = self.cache.fetch_nearest_cache("model", [1]) + self.assertIsNone(result) + + # [2], [3], [4] should still exist + for tokens in [[2], [3], [4]]: + # Re-insert since fetch extracts + self.cache.insert_cache("model", tokens, [f"cache{tokens[0]}"]) + + result2, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertIsNotNone(result2) + + def test_access_updates_lru_order(self): + """Accessing an entry should move it to most recently used.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.insert_cache("model", [2], ["cache2"]) + self.cache.insert_cache("model", [3], ["cache3"]) + + # Access [1] to make it most recently used + cache1, _ = self.cache.fetch_nearest_cache("model", [1]) + # Re-insert it (simulating normal usage pattern) + self.cache.insert_cache("model", [1], cache1) + + # Now insert two more entries - should evict [2] then [3], not [1] + self.cache.insert_cache("model", [4], ["cache4"]) + self.cache.insert_cache("model", [5], ["cache5"]) + + # [1] should still exist (was accessed, so not evicted) + result1, _ = self.cache.fetch_nearest_cache("model", [1]) + self.assertIsNotNone(result1) + + # [2] should be evicted (was oldest after [1] was accessed) + result2, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertIsNone(result2) + + +class TestCacheReferenceCount(unittest.TestCase): + """Tests 
for reference counting behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_multiple_inserts_increment_count(self): + """Inserting same tokens multiple times should increment count.""" + tokens = [1, 2, 3] + + self.cache.insert_cache("model", tokens, ["cache"]) + self.cache.insert_cache("model", tokens, ["cache"]) + self.cache.insert_cache("model", tokens, ["cache"]) + + # Should still be one entry (with count=3 internally) + self.assertEqual(len(self.cache), 1) + + # First two fetches should return copies (count decremented) + result1, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result1) + + result2, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result2) + + # Third fetch extracts the last reference + result3, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result3) + + # Fourth fetch should return None (entry fully extracted) + result4, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNone(result4) + + def test_extract_with_high_count_returns_deep_copy(self): + """When count > 1, extract should return a deep copy.""" + tokens = [1, 2, 3] + original_cache = [{"nested": "data"}] + + self.cache.insert_cache("model", tokens, original_cache) + self.cache.insert_cache("model", tokens, original_cache) # count=2 + + result1, _ = self.cache.fetch_nearest_cache("model", tokens) + + # Modify the returned cache + result1[0]["nested"] = "modified" + + # Second fetch should get unmodified copy + result2, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertEqual(result2[0]["nested"], "data") + + +class TestCacheMultiModel(unittest.TestCase): + """Tests for multi-model namespacing.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_same_tokens_different_models_are_separate(self): + """Same token sequence under different models should be independent.""" + tokens = [1, 2, 3] + + self.cache.insert_cache("model_a", tokens, ["cache_a"]) + self.cache.insert_cache("model_b", tokens, ["cache_b"]) + + self.assertEqual(len(self.cache), 2) + + result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens) + result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens) + + self.assertEqual(result_a, ["cache_a"]) + self.assertEqual(result_b, ["cache_b"]) + + def test_eviction_across_models(self): + """LRU eviction should work across different models.""" + cache = ThreadSafeLRUPromptCache(max_size=3) + + cache.insert_cache("model_a", [1], ["a1"]) + cache.insert_cache("model_b", [1], ["b1"]) + cache.insert_cache("model_a", [2], ["a2"]) + + self.assertEqual(len(cache), 3) + + # Insert 4th - should evict model_a:[1] (oldest) + cache.insert_cache("model_b", [2], ["b2"]) + + result, _ = cache.fetch_nearest_cache("model_a", [1]) + self.assertIsNone(result) + + +class TestCacheThreadSafety(unittest.TestCase): + """Tests for thread safety with data integrity verification.""" + + def test_concurrent_inserts_no_data_loss(self): + """Concurrent inserts should not lose data.""" + cache = ThreadSafeLRUPromptCache(max_size=100) + num_threads = 10 + inserts_per_thread = 20 + + def insert_entries(thread_id): + for i in range(inserts_per_thread): + tokens = [thread_id, i] + cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)] + 
concurrent.futures.wait(futures) + + # Verify expected number of entries (may be less due to LRU eviction with max_size=100) + # But should be exactly 100 since we inserted exactly 200 and max_size is 100 + self.assertEqual(len(cache), 100) + + def test_concurrent_fetch_and_insert_no_corruption(self): + """Concurrent fetches and inserts should not corrupt data.""" + cache = ThreadSafeLRUPromptCache(max_size=50) + errors = [] + lock = threading.Lock() + + # Pre-populate with known data + for i in range(20): + cache.insert_cache("model", [i], [f"original_{i}"]) + + def fetch_and_verify(thread_id): + try: + for _ in range(50): + token_id = thread_id % 20 + result, remaining = cache.fetch_nearest_cache("model", [token_id]) + + if result is not None: + # Verify data integrity + expected_prefix = f"original_{token_id}" + if not str(result[0]).startswith("original_"): + with lock: + errors.append(f"Corrupted data: {result}") + + # Re-insert to keep cache populated + cache.insert_cache("model", [token_id], result) + + except Exception as e: + with lock: + errors.append(str(e)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)] + concurrent.futures.wait(futures) + + self.assertEqual(errors, [], f"Thread safety errors: {errors}") + + def test_concurrent_operations_maintain_cache_bounds(self): + """Cache size should never exceed max_size under concurrent operations.""" + max_size = 10 + cache = ThreadSafeLRUPromptCache(max_size=max_size) + size_violations = [] + lock = threading.Lock() + + def random_operations(thread_id): + import random + for i in range(100): + tokens = [random.randint(0, 50)] + if random.random() < 0.7: + cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) + else: + cache.fetch_nearest_cache("model", tokens) + + current_size = len(cache) + if current_size > max_size: + with lock: + size_violations.append(current_size) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(random_operations, tid) for tid in range(10)] + concurrent.futures.wait(futures) + + self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}") + self.assertLessEqual(len(cache), max_size) + + +class TestCacheClear(unittest.TestCase): + """Tests for cache clear operation.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_clear_removes_all_entries(self): + """Clear should remove all entries.""" + self.cache.insert_cache("model1", [1, 2], ["cache1"]) + self.cache.insert_cache("model2", [3, 4], ["cache2"]) + self.cache.insert_cache("model1", [5, 6], ["cache3"]) + + self.assertEqual(len(self.cache), 3) + + self.cache.clear() + + self.assertEqual(len(self.cache), 0) + + def test_clear_allows_new_inserts(self): + """After clear, new inserts should work normally.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.clear() + self.cache.insert_cache("model", [2], ["cache2"]) + + self.assertEqual(len(self.cache), 1) + + result, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertEqual(result, ["cache2"]) + + +if __name__ == "__main__": + unittest.main()
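
Reviewer note (not part of the diff): a minimal sketch of how the new ThreadSafeLRUPromptCache behaves across its three lookup modes, assuming mlx_cache.py is importable. The list-shaped "KV caches", the integer token IDs, and the fake_can_trim/fake_trim helpers are illustrative stand-ins, not the real mlx_lm caches or mlx_lm.models.cache helpers wired up in backend.py.

    from mlx_cache import ThreadSafeLRUPromptCache

    def fake_can_trim(cache):
        # Stand-in for mlx_lm.models.cache.can_trim_prompt_cache.
        return True

    def fake_trim(cache, num_to_trim):
        # Stand-in for mlx_lm.models.cache.trim_prompt_cache:
        # drop the last num_to_trim positions of the fake list cache.
        del cache[len(cache) - num_to_trim:]

    lru = ThreadSafeLRUPromptCache(max_size=4, can_trim_fn=fake_can_trim, trim_fn=fake_trim)

    # 1. No match: the caller builds a fresh cache, prefills the whole prompt,
    #    then inserts prompt + generated tokens after generation.
    cache, remaining = lru.fetch_nearest_cache("model", [1, 2, 3, 4])
    assert cache is None and remaining == [1, 2, 3, 4]
    lru.insert_cache("model", [1, 2, 3, 4], ["kv1", "kv2", "kv3", "kv4"])

    # 2. Shorter-prefix match: only the new suffix still needs to be prefilled.
    #    fetch_nearest_cache extracts the entry (count == 1), so the caller
    #    reinserts it once generation finishes, as Predict/PredictStream do.
    cache, remaining = lru.fetch_nearest_cache("model", [1, 2, 3, 4, 5, 6])
    assert cache == ["kv1", "kv2", "kv3", "kv4"] and remaining == [5, 6]
    lru.insert_cache("model", [1, 2, 3, 4, 5, 6], cache + ["kv5", "kv6"])

    # 3. Longer-prefix match: a trimmed deep copy is returned, at least one token
    #    (here [3]) is left to process, and the longer entry stays in the cache.
    cache, remaining = lru.fetch_nearest_cache("model", [1, 2, 3])
    assert len(cache) == 2 and remaining == [3]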