Mirror of https://github.com/mudler/LocalAI.git (synced 2025-12-25 07:29:07 -05:00)
* feat(mlx): add thread-safe LRU prompt cache

  Port mlx-lm's LRUPromptCache to fix a race condition where concurrent requests corrupt shared KV cache state. The previous implementation used a single prompt_cache instance shared across all requests.

  Changes:
  - Add backend/python/common/mlx_cache.py with ThreadSafeLRUPromptCache
  - Modify backend.py to use per-request cache isolation via fetch/insert
  - Add prefix matching for cache reuse across similar prompts
  - Add LRU eviction (default 10 entries, configurable)
  - Add concurrency and cache unit tests

  The cache uses a trie-based structure for efficient prefix matching, allowing prompts that share common prefixes to reuse cached KV states. Thread safety is provided via threading.Lock.

  New configuration options:
  - max_cache_entries: Maximum LRU cache entries (default: 10)
  - max_kv_size: Maximum KV cache size per entry (default: None)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* feat(mlx): add min_p and top_k sampler support

  Add a MinP field to the proto (field 52), following the precedent set by other non-OpenAI sampling parameters such as TopK, TailFreeSamplingZ, TypicalP, and Mirostat.

  Changes:
  - backend.proto: add a float MinP field for min-p sampling
  - backend.py: extract and pass min_p and top_k to the mlx_lm sampler (top_k was in the proto but not being passed)
  - test.py: fix test_sampling_params to use valid proto fields and switch to an MLX-compatible model (mlx-community/Llama-3.2-1B-Instruct)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* refactor(mlx): move mlx_cache.py from common to mlx backend

  The ThreadSafeLRUPromptCache is only used by the mlx backend. After evaluating mlx-vlm, it was determined that the cache cannot be shared because mlx-vlm's generate/stream_generate functions don't support the prompt_cache parameter that mlx_lm provides.

  - Move mlx_cache.py from backend/python/common/ to backend/python/mlx/
  - Remove sys.path manipulation from backend.py and test.py
  - Fix test assertion to expect "MLX model loaded successfully"

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* test(mlx): add comprehensive cache tests and document upstream behavior

  Added comprehensive unit tests (test_mlx_cache.py) covering all cache operation modes:
  - Exact match
  - Shorter prefix match
  - Longer prefix match with trimming
  - No match scenarios
  - LRU eviction and access order
  - Reference counting and deep copy behavior
  - Multi-model namespacing
  - Thread safety with data integrity verification

  Documents upstream mlx_lm/server.py behavior: single-token prefixes are deliberately not matched (uses > 0, not >= 0) so that longer cached sequences are preferred for trimming. This is acceptable because real prompts with chat templates are always many tokens long.

  Removed weak unit tests from test.py that only verified "no exception thrown" rather than correctness.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* chore(mlx): remove unused MinP proto field

  The MinP field was added to PredictOptions but is not populated by the Go frontend/API. The MLX backend uses getattr with a default value, so it works without the proto field.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

---------

Signed-off-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
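To make the cache lifecycle above concrete, here is a minimal sketch of the fetch → generate → insert flow that Predict/PredictStream below implement. It is illustrative only, not the shipped code: it assumes mlx_lm is installed and that mlx_cache.py from this backend is importable, and the model name (mlx-community/Llama-3.2-1B-Instruct, the one the tests use), the prompt, and the "example-model-key" string are placeholders.

# Minimal sketch of the per-request cache lifecycle (illustrative, not the shipped code).
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
from mlx_cache import ThreadSafeLRUPromptCache

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct")
lru = ThreadSafeLRUPromptCache(max_size=10,
                               can_trim_fn=can_trim_prompt_cache,
                               trim_fn=trim_prompt_cache)

tokens = list(tokenizer.encode("Hello, how are you?"))   # the token list doubles as the cache key
cache, remaining = lru.fetch_nearest_cache("example-model-key", tokens)
if cache is None:                                        # no cached prefix could be reused
    cache = make_prompt_cache(model)
    remaining = tokens

for resp in stream_generate(model, tokenizer, prompt=remaining,
                            max_tokens=16, prompt_cache=cache):
    tokens.append(resp.token)                            # grow the key with generated tokens

lru.insert_cache("example-model-key", tokens, cache)     # later prompts can reuse this prefix

A second request whose prompt starts with the same tokens would get back the cached KV state and only its new suffix in remaining, which is exactly the cache_hit condition logged by Predict and PredictStream below.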
451 lines
16 KiB
Python
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
from typing import List
import time

import backend_pb2
import backend_pb2_grpc

import grpc
from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
import base64
import io

from mlx_cache import ThreadSafeLRUPromptCache

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model using MLX.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
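            # e.g. Options=["max_cache_entries:20", "max_kv_size:4096", "trust_remote_code:true"]
            # (illustrative values; the keys shown are the ones read further down in this method)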
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types
                if is_float(value):
                    value = float(value)
                elif is_int(value):
                    value = int(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Build tokenizer config for MLX using options
            tokenizer_config = {}

            # Handle trust_remote_code from request or options
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True

            # Handle EOS token from options
            if "eos_token" in self.options:
                tokenizer_config["eos_token"] = self.options["eos_token"]

            # Handle other tokenizer config options
            for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            # Load model and tokenizer using MLX
            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
            else:
                self.model, self.tokenizer = load(request.Model)

            # Initialize thread-safe LRU prompt cache for efficient generation
            max_cache_entries = self.options.get("max_cache_entries", 10)
            self.max_kv_size = self.options.get("max_kv_size", None)
            self.model_key = request.Model
            self.lru_cache = ThreadSafeLRUPromptCache(
                max_size=max_cache_entries,
                can_trim_fn=can_trim_prompt_cache,
                trim_fn=trim_prompt_cache,
            )

        except Exception as err:
            print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")

        print("MLX model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters using MLX.

        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        prompt_cache = None
        cache_key = None

        try:
            # Prepare the prompt and tokenize for cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                self.model_key, cache_key
            )
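            # On a hit, fetch_nearest_cache returns a reusable KV cache plus the token suffix
            # that still needs to be processed; on a miss it returns None, and a fresh cache is
            # created below so the full prompt is processed.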
            if prompt_cache is None:
                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request)

            print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Use stream_generate to track generated tokens for cache key
            generated_text = []
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                generated_text.append(response.text)
                cache_key.append(response.token)

            # Insert completed cache
            self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Note: MLX-LM doesn't support embeddings directly. This method returns an error.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
        """
        print("Embeddings not supported in MLX backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX backend.")
        return backend_pb2.EmbeddingResult()

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX.

        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: Streaming predict results.
        """
        prompt_cache = None
        cache_key = None

        try:
            # Prepare the prompt and tokenize for cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                self.model_key, cache_key
            )
            if prompt_cache is None:
                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Stream text generation using MLX with proper parameters
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))

        finally:
            # Always insert cache, even on interruption
            if prompt_cache is not None and cache_key is not None:
                try:
                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
                except Exception as e:
                    print(f"Error inserting cache: {e}", file=sys.stderr)

    def _prepare_prompt(self, request):
        """
        Prepare the prompt for MLX generation, handling chat templates if needed.

        Args:
            request: The gRPC request containing prompt and message information.

        Returns:
            str: The prepared prompt.
        """
        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})

            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        else:
            return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        """
        Tokenize prompt text for cache key generation.

        Args:
            prompt_text: The prompt string to tokenize.

        Returns:
            List[int]: List of token IDs.
        """
        tokens = self.tokenizer.encode(prompt_text)
        if hasattr(tokens, 'tolist'):
            return tokens.tolist()
        return list(tokens)

    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, sampler_params dict)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract sampler parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        min_p = getattr(request, 'MinP', 0.0)
        # min_p default of 0.0 means disabled (no filtering)

        top_k = getattr(request, 'TopK', 0)
        # top_k default of 0 means disabled (no filtering)

        # Initialize sampler parameters
        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': min_p,
            'top_k': top_k,
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        # Add seed if specified
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Sampler parameters from options
            sampler_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
                'min_p': 'min_p',
                'top_k': 'top_k',
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }

            for option_key, param_key in sampler_option_mapping.items():
                if option_key in self.options:
                    sampler_params[param_key] = self.options[option_key]

            # Handle seed from options
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # Special tokens for XTC sampling (if tokenizer has eos_token_ids)
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]

        # Add newline token if available
        try:
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except Exception:
            pass  # Skip if encoding fails

        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params

async def serve(address):
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                             options=[
                                 ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                             ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
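For manual testing against a running instance of this server, a small client sketch along the following lines can exercise LoadModel and Predict over gRPC. The stub and message names (BackendStub, ModelOptions, PredictOptions) are assumptions based on backend.proto and are not shown in this file; the model name and option values are placeholders.

# Hypothetical client sketch for manual testing, not part of the backend. It assumes the stubs
# generated from backend.proto expose BackendStub, ModelOptions and PredictOptions with the
# field names referenced above (Model, Options, Prompt, Tokens, Temperature, TopP, TopK);
# the model name and option values are illustrative.
import grpc

import backend_pb2
import backend_pb2_grpc


def main(addr: str = "localhost:50051") -> None:
    with grpc.insecure_channel(addr) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)

        # Load the model, passing the new cache options as "optname:optvalue" strings.
        result = stub.LoadModel(backend_pb2.ModelOptions(
            Model="mlx-community/Llama-3.2-1B-Instruct",
            Options=["max_cache_entries:20", "max_kv_size:4096"],
        ))
        print(result.success, result.message)

        # Run a prediction; Temperature/TopP/TopK feed the sampler built in _build_generation_params.
        reply = stub.Predict(backend_pb2.PredictOptions(
            Prompt="Hello",
            Tokens=32,
            Temperature=0.6,
            TopP=0.95,
            TopK=40,
        ))
        print(reply.message.decode("utf-8"))


if __name__ == "__main__":
    main()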