LocalAI/backend/python/mlx/test.py
blightbow 67baf66555 feat(mlx): add thread-safe LRU prompt cache and min_p/top_k sampling (#7556)
* feat(mlx): add thread-safe LRU prompt cache

Port mlx-lm's LRUPromptCache to fix race condition where concurrent
requests corrupt shared KV cache state. The previous implementation
used a single prompt_cache instance shared across all requests.

Changes:
- Add backend/python/common/mlx_cache.py with ThreadSafeLRUPromptCache
- Modify backend.py to use per-request cache isolation via fetch/insert
- Add prefix matching for cache reuse across similar prompts
- Add LRU eviction (default 10 entries, configurable)
- Add concurrency and cache unit tests

The cache uses a trie-based structure for efficient prefix matching,
allowing prompts that share common prefixes to reuse cached KV states.
Thread safety is provided via threading.Lock.
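
As a rough illustration of the fetch/insert pattern described above, here is a minimal sketch. It is not the code in mlx_cache.py: the class name `TinyPrefixLRUCache` is hypothetical, a linear scan over an `OrderedDict` stands in for the trie, and the cached state is an arbitrary Python object rather than a real KV cache. Returning a deep copy from `fetch` is an assumption made here so that concurrent callers never mutate a shared entry.

```python
import copy
import threading
from collections import OrderedDict


class TinyPrefixLRUCache:
    """Hypothetical, simplified stand-in for a thread-safe LRU prompt cache."""

    def __init__(self, max_entries=10):
        self.max_entries = max_entries
        self._entries = OrderedDict()  # tuple(tokens) -> cached state
        self._lock = threading.Lock()

    def insert(self, tokens, state):
        """Store state under the token sequence, evicting the LRU entry if full."""
        key = tuple(tokens)
        with self._lock:
            self._entries[key] = state
            self._entries.move_to_end(key)  # mark as most recently used
            while len(self._entries) > self.max_entries:
                self._entries.popitem(last=False)  # drop least recently used

    def fetch(self, tokens):
        """Return (copy of cached state, uncached remainder) for the longest cached prefix."""
        tokens = tuple(tokens)
        with self._lock:
            best = None
            # Linear scan instead of a trie: fine for a sketch, O(entries) per lookup.
            for key in self._entries:
                if tokens[:len(key)] == key and (best is None or len(key) > len(best)):
                    best = key
            if best is None:
                return None, list(tokens)
            self._entries.move_to_end(best)
            # Deep copy so the caller gets an isolated, per-request state.
            return copy.deepcopy(self._entries[best]), list(tokens[len(best):])
```

A request would call `fetch` with its prompt tokens, process only the returned remainder, and then `insert` the updated state under the full token sequence.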

New configuration options:
- max_cache_entries: Maximum LRU cache entries (default: 10)
- max_kv_size: Maximum KV cache size per entry (default: None)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* feat(mlx): add min_p and top_k sampler support

Add MinP field to proto (field 52) following the precedent set by
other non-OpenAI sampling parameters like TopK, TailFreeSamplingZ,
TypicalP, and Mirostat.

Changes:
- backend.proto: Add float MinP field for min-p sampling
- backend.py: Extract and pass min_p and top_k to mlx_lm sampler
  (top_k was in proto but not being passed)
- test.py: Fix test_sampling_params to use valid proto fields and
  switch to MLX-compatible model (mlx-community/Llama-3.2-1B-Instruct)
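
For readers unfamiliar with the two samplers, the sketch below shows what top-k and min-p filtering do to a probability distribution. It is a pure-Python illustration, not the mlx_lm sampler that backend.py calls; the function name `filter_top_k_min_p` is hypothetical.

```python
def filter_top_k_min_p(probs, top_k=0, min_p=0.0):
    """Return {token_index: renormalized probability} after top-k and min-p filtering.

    top_k <= 0 and min_p <= 0.0 disable the respective filter.
    """
    keep = list(range(len(probs)))
    if top_k > 0:
        # Top-k: keep only the k most probable tokens.
        keep = sorted(keep, key=lambda i: probs[i], reverse=True)[:top_k]
    if min_p > 0.0:
        # Min-p: keep tokens whose probability is at least min_p times the
        # probability of the most likely token.
        threshold = min_p * max(probs)
        keep = [i for i in keep if probs[i] >= threshold]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in sorted(keep)}


filtered = filter_top_k_min_p([0.5, 0.3, 0.15, 0.05], top_k=3, min_p=0.4)
```

With these inputs, top-k drops the least likely token and min-p then drops every token below 0.4 * 0.5 = 0.2, leaving only the first two tokens.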

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* refactor(mlx): move mlx_cache.py from common to mlx backend

The ThreadSafeLRUPromptCache is only used by the mlx backend. After
evaluating mlx-vlm, it was determined that the cache cannot be shared
because mlx-vlm's generate/stream_generate functions don't support
the prompt_cache parameter that mlx_lm provides.

- Move mlx_cache.py from backend/python/common/ to backend/python/mlx/
- Remove sys.path manipulation from backend.py and test.py
- Fix test assertion to expect "MLX model loaded successfully"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* test(mlx): add comprehensive cache tests and document upstream behavior

Added comprehensive unit tests (test_mlx_cache.py) covering all cache
operation modes:
- Exact match
- Shorter prefix match
- Longer prefix match with trimming
- No match scenarios
- LRU eviction and access order
- Reference counting and deep copy behavior
- Multi-model namespacing
- Thread safety with data integrity verification

Documents upstream mlx_lm/server.py behavior: single-token prefixes are
deliberately not matched (uses > 0, not >= 0) to allow longer cached
sequences to be preferred for trimming. This is acceptable because real
prompts with chat templates are always many tokens.

Removed weak unit tests from test.py that only verified "no exception
thrown" rather than correctness.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* chore(mlx): remove unused MinP proto field

The MinP field was added to PredictOptions but is not populated by the
Go frontend/API. The MLX backend uses getattr with a default value,
so it works without the proto field.
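
The getattr fallback mentioned above can be sketched as follows. This is an illustration only: `read_sampling_options` is a hypothetical helper, `SimpleNamespace` stands in for the generated PredictOptions message, and the default values shown are assumptions rather than the ones used in backend.py.

```python
from types import SimpleNamespace


def read_sampling_options(request):
    """Read optional sampling fields, falling back to defaults when a field is absent."""
    return {
        "top_k": getattr(request, "TopK", 0),    # present in the proto
        "min_p": getattr(request, "MinP", 0.0),  # absent from the proto, so the default applies
    }


# Stand-in request object that has TopK but no MinP attribute.
request = SimpleNamespace(TopK=40)
options = read_sampling_options(request)
```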

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

---------

Signed-off-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 11:27:46 +01:00


import concurrent.futures
import subprocess
import time
import unittest

import grpc

import backend_pb2
import backend_pb2_grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.
    This class contains methods to test the startup and shutdown of the gRPC service.
    """

    # unittest runs setUp before and tearDown after every test method, so the
    # tests below do not call them directly. Calling setUp a second time would
    # spawn another server on the same port and leak the first process.
    def setUp(self):
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        # Give the backend time to start listening
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up and answers the health check
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "MLX model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_text(self):
        """
        This method tests if text is generated successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")

    def test_sampling_params(self):
        """
        This method tests if all sampling parameters are correctly processed
        NOTE: this does NOT test for correctness, just that we received a compatible response
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                # MinP is not set here: per the commit message, the MinP proto
                # field was removed, and protobuf raises ValueError for
                # keyword arguments that are not fields of the message.
                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    PresencePenalty=0.1,
                    FrequencyPenalty=0.2,
                    Seed=42,
                    StopPrompts=["\n"],
                    IgnoreEOS=True,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")

    def test_embedding(self):
        """
        This method tests if the embeddings are generated successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
                self.assertTrue(response.success)
                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                embedding_response = stub.Embedding(embedding_request)
                self.assertIsNotNone(embedding_response.embeddings)
                # A protobuf repeated field is a container type, not a list,
                # so convert it before checking the contents
                embeddings = list(embedding_response.embeddings)
                # assert that the list is not empty
                self.assertTrue(len(embeddings) > 0)
                # assert that it is a list of floats
                self.assertTrue(all(isinstance(value, float) for value in embeddings))
        except Exception as err:
            print(err)
            self.fail("Embedding service failed")

    def test_concurrent_requests(self):
        """
        This method tests that concurrent requests don't corrupt each other's cache state.
        This is a regression test for the race condition in the original implementation.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                def make_request(prompt):
                    req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20)
                    return stub.Predict(req)

                # Run 5 concurrent requests with different prompts
                prompts = [
                    "The capital of France is",
                    "The capital of Germany is",
                    "The capital of Italy is",
                    "The capital of Spain is",
                    "The capital of Portugal is",
                ]
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    futures = [executor.submit(make_request, p) for p in prompts]
                    results = [f.result() for f in concurrent.futures.as_completed(futures)]
                # All results should be non-empty
                messages = [r.message for r in results]
                self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses")
                print(f"Concurrent test passed: {len(messages)} responses received")
        except Exception as err:
            print(err)
            self.fail("Concurrent requests test failed")

    def test_cache_reuse(self):
        """
        This method tests that repeated prompts reuse cached KV states.
        The second request should benefit from the cached prompt processing.
        NOTE: this only checks that both requests succeed; it does not measure cache hits
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                prompt = "The quick brown fox jumps over the lazy dog. "
                # First request - populates cache
                req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
                resp1 = stub.Predict(req1)
                self.assertIsNotNone(resp1.message)
                # Second request with same prompt - should reuse cache
                req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
                resp2 = stub.Predict(req2)
                self.assertIsNotNone(resp2.message)
                print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes")
        except Exception as err:
            print(err)
            self.fail("Cache reuse test failed")

    def test_prefix_cache_reuse(self):
        """
        This method tests that prompts sharing a common prefix benefit from cached KV states.
        NOTE: this only checks that both requests succeed; it does not measure cache hits
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                # First request with base prompt
                prompt_base = "Once upon a time in a land far away, "
                req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10)
                resp1 = stub.Predict(req1)
                self.assertIsNotNone(resp1.message)
                # Second request with extended prompt (same prefix)
                prompt_extended = prompt_base + "there lived a brave knight who "
                req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10)
                resp2 = stub.Predict(req2)
                self.assertIsNotNone(resp2.message)
                print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes")
        except Exception as err:
            print(err)
            self.fail("Prefix cache reuse test failed")


# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py