""" Comprehensive unit tests for ThreadSafeLRUPromptCache. Tests all cache operation modes: - Exact match - Shorter prefix match - Longer prefix match (with trimming) - No match - LRU eviction - Reference counting - Multi-model namespacing - Thread safety with data integrity verification """ import unittest import concurrent.futures import threading import copy from mlx_cache import ThreadSafeLRUPromptCache class TestCacheExactMatch(unittest.TestCase): """Tests for exact match cache behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_exact_match_returns_cache_and_empty_remaining(self): """Exact match should return the cache with no remaining tokens.""" tokens = [1, 2, 3, 4, 5] mock_cache = ["kv_cache_data"] self.cache.insert_cache("model1", tokens, mock_cache) result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertEqual(result_cache, mock_cache) self.assertEqual(remaining, []) def test_exact_match_extracts_and_removes_from_cache(self): """Fetching exact match with count=1 should remove entry from cache.""" tokens = [1, 2, 3] self.cache.insert_cache("model1", tokens, ["cache"]) self.assertEqual(len(self.cache), 1) # First fetch extracts the entry self.cache.fetch_nearest_cache("model1", tokens) # Cache should now be empty self.assertEqual(len(self.cache), 0) # Second fetch should return None (no match) result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertIsNone(result_cache) self.assertEqual(remaining, tokens) class TestCacheShorterPrefix(unittest.TestCase): """Tests for shorter prefix match behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_shorter_prefix_returns_cache_with_remaining_tokens(self): """When cached prefix is shorter, return cache and remaining suffix.""" short_tokens = [1, 2, 3] long_tokens = [1, 2, 3, 4, 5, 6] mock_cache = ["prefix_cache"] self.cache.insert_cache("model1", short_tokens, mock_cache) result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) self.assertEqual(result_cache, mock_cache) self.assertEqual(remaining, [4, 5, 6]) def test_shorter_prefix_correct_remaining_calculation(self): """Verify remaining tokens are calculated correctly for various prefix lengths.""" # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched # to allow longer cached sequences to be preferred for trimming. # This matches upstream mlx_lm/server.py behavior. test_cases = [ # (cached_tokens, requested_tokens, expected_remaining) ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]), ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]), ] for cached, requested, expected_remaining in test_cases: with self.subTest(cached=cached, requested=requested): cache = ThreadSafeLRUPromptCache(max_size=10) cache.insert_cache("model", cached, ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model", requested) self.assertIsNotNone(result_cache) self.assertEqual(remaining, expected_remaining) def test_single_token_prefix_not_matched(self): """Single-token prefixes are not matched (by design, matches upstream). This allows longer cached sequences to be preferred for trimming, which provides better KV cache reuse. Single-token caches are rare in practice since real prompts with chat templates are many tokens. 
""" cache = ThreadSafeLRUPromptCache(max_size=10) cache.insert_cache("model", [1], ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3]) # Single-token prefix is NOT matched self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 3]) class TestCacheLongerPrefix(unittest.TestCase): """Tests for longer prefix match behavior (trimming).""" def setUp(self): # Track trim calls for verification self.trim_calls = [] def mock_can_trim(cache): return True def mock_trim(cache, num_to_trim): self.trim_calls.append(num_to_trim) # Simulate trimming by modifying the cache cache.append(f"trimmed_{num_to_trim}") self.cache = ThreadSafeLRUPromptCache( max_size=10, can_trim_fn=mock_can_trim, trim_fn=mock_trim, ) def test_longer_prefix_triggers_trim(self): """When cached sequence is longer, should trim to match requested prefix.""" long_tokens = [1, 2, 3, 4, 5] short_tokens = [1, 2, 3] self.cache.insert_cache("model1", long_tokens, ["original_cache"]) result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens) # Should have called trim self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called") # Result should be a trimmed copy, not the original self.assertIn("trimmed_", str(result_cache)) def test_longer_prefix_without_trim_fn_returns_no_match(self): """Without trim functions, longer prefix should not match.""" cache_no_trim = ThreadSafeLRUPromptCache(max_size=10) long_tokens = [1, 2, 3, 4, 5] short_tokens = [1, 2, 3] cache_no_trim.insert_cache("model1", long_tokens, ["cache"]) result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens) # Without trim_fn, should return no match self.assertIsNone(result_cache) self.assertEqual(remaining, short_tokens) def test_longer_prefix_can_trim_false_returns_no_match(self): """When can_trim_fn returns False, should not attempt trim.""" cache = ThreadSafeLRUPromptCache( max_size=10, can_trim_fn=lambda c: False, trim_fn=lambda c, n: None, ) cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3]) self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 3]) class TestCacheNoMatch(unittest.TestCase): """Tests for no match behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_empty_cache_returns_none(self): """Empty cache should return None and all tokens as remaining.""" tokens = [1, 2, 3] result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertIsNone(result_cache) self.assertEqual(remaining, tokens) def test_different_prefix_returns_none(self): """Tokens with different prefix should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) # Completely different tokens result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6]) self.assertIsNone(result_cache) self.assertEqual(remaining, [4, 5, 6]) def test_partial_prefix_mismatch_returns_none(self): """Tokens that diverge mid-sequence should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) # Same start but diverges result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99]) self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 99]) def test_wrong_model_returns_none(self): """Different model key should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3]) self.assertIsNone(result_cache) 
class TestCacheNoMatch(unittest.TestCase):
    """Tests for no match behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_empty_cache_returns_none(self):
        """Empty cache should return None and all tokens as remaining."""
        tokens = [1, 2, 3]

        result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, tokens)

    def test_different_prefix_returns_none(self):
        """Tokens with different prefix should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        # Completely different tokens
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [4, 5, 6])

    def test_partial_prefix_mismatch_returns_none(self):
        """Tokens that diverge mid-sequence should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        # Same start but diverges
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 99])

    def test_wrong_model_returns_none(self):
        """Different model key should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 3])


class TestCacheLRUEviction(unittest.TestCase):
    """Tests for LRU eviction behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=3)

    def test_evicts_oldest_when_full(self):
        """Should evict least recently used entry when capacity exceeded."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.insert_cache("model", [2], ["cache2"])
        self.cache.insert_cache("model", [3], ["cache3"])
        self.assertEqual(len(self.cache), 3)

        # Insert 4th entry - should evict [1]
        self.cache.insert_cache("model", [4], ["cache4"])
        self.assertEqual(len(self.cache), 3)

        # [1] should be evicted
        result, _ = self.cache.fetch_nearest_cache("model", [1])
        self.assertIsNone(result)

        # [2], [3], [4] should still exist (each fetch extracts its own entry,
        # so the three checks are independent)
        for tokens in [[2], [3], [4]]:
            result, _ = self.cache.fetch_nearest_cache("model", tokens)
            self.assertIsNotNone(result, f"{tokens} should not have been evicted")

    def test_access_updates_lru_order(self):
        """Accessing an entry should move it to most recently used."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.insert_cache("model", [2], ["cache2"])
        self.cache.insert_cache("model", [3], ["cache3"])

        # Access [1] to make it most recently used
        cache1, _ = self.cache.fetch_nearest_cache("model", [1])
        # Re-insert it (simulating normal usage pattern)
        self.cache.insert_cache("model", [1], cache1)

        # Now insert two more entries - should evict [2] then [3], not [1]
        self.cache.insert_cache("model", [4], ["cache4"])
        self.cache.insert_cache("model", [5], ["cache5"])

        # [1] should still exist (was accessed, so not evicted)
        result1, _ = self.cache.fetch_nearest_cache("model", [1])
        self.assertIsNotNone(result1)

        # [2] should be evicted (was oldest after [1] was accessed)
        result2, _ = self.cache.fetch_nearest_cache("model", [2])
        self.assertIsNone(result2)


class TestCacheReferenceCount(unittest.TestCase):
    """Tests for reference counting behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_multiple_inserts_increment_count(self):
        """Inserting same tokens multiple times should increment count."""
        tokens = [1, 2, 3]
        self.cache.insert_cache("model", tokens, ["cache"])
        self.cache.insert_cache("model", tokens, ["cache"])
        self.cache.insert_cache("model", tokens, ["cache"])

        # Should still be one entry (with count=3 internally)
        self.assertEqual(len(self.cache), 1)

        # First two fetches should return copies (count decremented)
        result1, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result1)
        result2, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result2)

        # Third fetch extracts the last reference
        result3, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result3)

        # Fourth fetch should return None (entry fully extracted)
        result4, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNone(result4)

    def test_extract_with_high_count_returns_deep_copy(self):
        """When count > 1, extract should return a deep copy."""
        tokens = [1, 2, 3]
        original_cache = [{"nested": "data"}]
        self.cache.insert_cache("model", tokens, original_cache)
        self.cache.insert_cache("model", tokens, original_cache)  # count=2

        result1, _ = self.cache.fetch_nearest_cache("model", tokens)
        # Modify the returned cache
        result1[0]["nested"] = "modified"

        # Second fetch should get unmodified copy
        result2, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertEqual(result2[0]["nested"], "data")
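# The fetch-extracts / re-insert pattern asserted above mirrors how a
# generation loop would typically use the cache. Illustrative sketch only:
# `make_kv_cache` and `generate_step` are hypothetical helpers, not part of
# mlx_cache or of these tests.
#
#     kv, remaining = cache.fetch_nearest_cache(model_key, prompt_tokens)
#     if kv is None:
#         kv = make_kv_cache(model)                     # hypothetical: fresh KV cache
#         remaining = prompt_tokens
#     completion = generate_step(model, remaining, kv)  # hypothetical: fills kv in place
#     # Re-insert under the full token sequence so a follow-up request that
#     # continues this conversation gets an exact or prefix match.
#     cache.insert_cache(model_key, prompt_tokens + completion, kv)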
class TestCacheMultiModel(unittest.TestCase):
    """Tests for multi-model namespacing."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_same_tokens_different_models_are_separate(self):
        """Same token sequence under different models should be independent."""
        tokens = [1, 2, 3]
        self.cache.insert_cache("model_a", tokens, ["cache_a"])
        self.cache.insert_cache("model_b", tokens, ["cache_b"])

        self.assertEqual(len(self.cache), 2)

        result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens)
        result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens)

        self.assertEqual(result_a, ["cache_a"])
        self.assertEqual(result_b, ["cache_b"])

    def test_eviction_across_models(self):
        """LRU eviction should work across different models."""
        cache = ThreadSafeLRUPromptCache(max_size=3)
        cache.insert_cache("model_a", [1], ["a1"])
        cache.insert_cache("model_b", [1], ["b1"])
        cache.insert_cache("model_a", [2], ["a2"])
        self.assertEqual(len(cache), 3)

        # Insert 4th - should evict model_a:[1] (oldest)
        cache.insert_cache("model_b", [2], ["b2"])

        result, _ = cache.fetch_nearest_cache("model_a", [1])
        self.assertIsNone(result)


class TestCacheThreadSafety(unittest.TestCase):
    """Tests for thread safety with data integrity verification."""

    def test_concurrent_inserts_no_data_loss(self):
        """Concurrent inserts should not lose data."""
        cache = ThreadSafeLRUPromptCache(max_size=100)
        num_threads = 10
        inserts_per_thread = 20

        def insert_entries(thread_id):
            for i in range(inserts_per_thread):
                tokens = [thread_id, i]
                cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)]
            concurrent.futures.wait(futures)

        # 200 unique entries were inserted into a cache capped at 100, so after
        # LRU eviction exactly max_size entries should remain.
        self.assertEqual(len(cache), 100)

    def test_concurrent_fetch_and_insert_no_corruption(self):
        """Concurrent fetches and inserts should not corrupt data."""
        cache = ThreadSafeLRUPromptCache(max_size=50)
        errors = []
        lock = threading.Lock()

        # Pre-populate with known data
        for i in range(20):
            cache.insert_cache("model", [i], [f"original_{i}"])

        def fetch_and_verify(thread_id):
            try:
                for _ in range(50):
                    token_id = thread_id % 20
                    result, remaining = cache.fetch_nearest_cache("model", [token_id])
                    if result is not None:
                        # Verify data integrity: this key only ever holds its own value
                        expected = f"original_{token_id}"
                        if result[0] != expected:
                            with lock:
                                errors.append(f"Corrupted data for {token_id}: {result}")
                        # Re-insert to keep cache populated
                        cache.insert_cache("model", [token_id], result)
            except Exception as e:
                with lock:
                    errors.append(str(e))

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)]
            concurrent.futures.wait(futures)

        self.assertEqual(errors, [], f"Thread safety errors: {errors}")

    def test_concurrent_operations_maintain_cache_bounds(self):
        """Cache size should never exceed max_size under concurrent operations."""
        max_size = 10
        cache = ThreadSafeLRUPromptCache(max_size=max_size)
        size_violations = []
        lock = threading.Lock()

        def random_operations(thread_id):
            import random

            for i in range(100):
                tokens = [random.randint(0, 50)]
                if random.random() < 0.7:
                    cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])
                else:
                    cache.fetch_nearest_cache("model", tokens)

                current_size = len(cache)
                if current_size > max_size:
                    with lock:
                        size_violations.append(current_size)

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(random_operations, tid) for tid in range(10)]
            concurrent.futures.wait(futures)

        self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}")
        self.assertLessEqual(len(cache), max_size)
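# The invariants checked above (no lost or corrupted entries, size never above
# max_size) are what a single lock around an ordered map provides. Minimal
# sketch of that pattern; this is an assumption about the implementation style,
# not the actual mlx_cache code:
#
#     import threading
#     from collections import OrderedDict
#
#     class _SketchThreadSafeLRU:
#         def __init__(self, max_size):
#             self._lock = threading.RLock()
#             self._entries = OrderedDict()          # (model, tuple(tokens)) -> value
#             self._max_size = max_size
#
#         def insert(self, key, value):
#             with self._lock:                       # serialize every mutation
#                 self._entries[key] = value
#                 self._entries.move_to_end(key)     # mark as most recently used
#                 while len(self._entries) > self._max_size:
#                     self._entries.popitem(last=False)   # evict the LRU entry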
class TestCacheClear(unittest.TestCase):
    """Tests for cache clear operation."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_clear_removes_all_entries(self):
        """Clear should remove all entries."""
        self.cache.insert_cache("model1", [1, 2], ["cache1"])
        self.cache.insert_cache("model2", [3, 4], ["cache2"])
        self.cache.insert_cache("model1", [5, 6], ["cache3"])
        self.assertEqual(len(self.cache), 3)

        self.cache.clear()

        self.assertEqual(len(self.cache), 0)

    def test_clear_allows_new_inserts(self):
        """After clear, new inserts should work normally."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.clear()

        self.cache.insert_cache("model", [2], ["cache2"])
        self.assertEqual(len(self.cache), 1)

        result, _ = self.cache.fetch_nearest_cache("model", [2])
        self.assertEqual(result, ["cache2"])


if __name__ == "__main__":
    unittest.main()