"""
|
|
Thread-safe LRU prompt cache for MLX-based backends.
|
|
|
|
Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
|
|
with thread-safety additions for LocalAI's gRPC backend.
|
|
|
|
Usage:
|
|
from mlx_cache import ThreadSafeLRUPromptCache
|
|
|
|
# In LoadModel:
|
|
self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)
|
|
|
|
# In Predict/PredictStream:
|
|
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
|
|
# ... generate ...
|
|
self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
|
|
"""
|
|
import copy
|
|
import threading
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
|
|
@dataclass
|
|
class CacheEntry:
|
|
"""A cache entry with reference counting."""
|
|
prompt_cache: List[Any]
|
|
count: int
|
|
|
|
|
|
@dataclass
|
|
class SearchResult:
|
|
"""Result of searching the cache trie."""
|
|
model: Any
|
|
exact: Optional[List[int]]
|
|
shorter: Optional[List[int]]
|
|
longer: Optional[List[int]]
|
|
common_prefix: int
|
|
|
|
|
|
class ThreadSafeLRUPromptCache:
|
|
"""
|
|
Thread-safe LRU cache with prefix matching for prompt KV caches.
|
|
|
|
This cache stores KV caches keyed by token sequences and supports:
|
|
- Exact match: Return the cache for the exact token sequence
|
|
- Shorter prefix match: Return a cache for a prefix of the tokens
|
|
- Longer prefix match: If a longer sequence is cached and can be trimmed
|
|
- LRU eviction: When max_size is exceeded, evict least recently used
|
|
|
|
Thread safety is provided via a threading.Lock that protects all
|
|
cache operations.
|
|
|
|
Args:
|
|
max_size: Maximum number of cache entries (default: 10)
|
|
can_trim_fn: Optional function to check if a cache can be trimmed
|
|
trim_fn: Optional function to trim a cache
|
|
"""

    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        self.max_size = max_size
        self._cache = {}
        self._lru = deque()
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn

    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return SearchResult(model, tuple(tokens), None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            if best is not None:
                longer = tuple(tokens[:index] + best)

        return SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie."""
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]

    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes."""
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                break
            del d_prev[t]

    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.
        """
        cache_entry = self._get(model, tokens)
        if cache_entry.count == 1:
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )

    def fetch_nearest_cache(
        self, model, tokens: List[int]
    ) -> Tuple[Optional[List[Any]], List[int]]:
        """
        Fetch the nearest cache for the given token sequence.

        Thread-safe. Returns (cache, remaining_tokens) where:
        - cache: The KV cache to use (or None if no cache found)
        - remaining_tokens: Tokens that still need to be processed

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence for the prompt

        Returns:
            Tuple of (prompt_cache, remaining_tokens)
        """
        with self._lock:
            tokens_tuple = tuple(tokens)
            result = self._search(model, tokens)

            # Exact match - extract and return
            if result.exact is not None:
                cache_entry = self._extract(result.model, result.exact)
                return cache_entry.prompt_cache, []

            # Shorter prefix match - extract and return remaining
            if result.shorter is not None:
                cache_entry = self._extract(result.model, result.shorter)
                prefix_len = len(result.shorter)
                return cache_entry.prompt_cache, list(tokens[prefix_len:])

            # Longer prefix match - try to trim if possible
            if result.longer is not None and self._can_trim_fn is not None:
                cache_entry = self._get(result.model, result.longer)
                if self._can_trim_fn(cache_entry.prompt_cache):
                    # Deep copy and trim
                    trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                    prefix = min(len(tokens) - 1, result.common_prefix)
                    num_to_trim = len(result.longer) - prefix
                    if self._trim_fn is not None:
                        self._trim_fn(trimmed_cache, num_to_trim)
                    return trimmed_cache, list(tokens[prefix:])

            # No match found
            return None, list(tokens)

    def insert_cache(
        self, model, tokens: List[int], prompt_cache: List[Any]
    ) -> None:
        """
        Insert a cache entry after generation completes.

        Thread-safe. Handles LRU eviction if max_size is exceeded.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence (prompt + generated)
            prompt_cache: The KV cache to store
        """
        with self._lock:
            tokens_tuple = tuple(tokens)

            if model not in self._cache:
                self._cache[model] = {}
            current = self._cache[model]

            # Build trie path
            for tok in tokens_tuple:
                if tok not in current:
                    current[tok] = {}
                current = current[tok]

            # Update or create entry
            if "cache" in current:
                current["cache"].count += 1
                self._lru.remove((model, tokens_tuple))
            else:
                current["cache"] = CacheEntry(prompt_cache, 1)

            # Update LRU order
            self._lru.append((model, tokens_tuple))

            # Evict if over capacity
            if len(self._lru) > self.max_size:
                evict_model, evict_tokens = self._lru.popleft()
                self._delete(evict_model, evict_tokens)

    def clear(self) -> None:
        """Clear all cache entries. Thread-safe."""
        with self._lock:
            self._cache.clear()
            self._lru.clear()

    def __len__(self) -> int:
        """Return the number of cache entries. Thread-safe."""
        with self._lock:
            return len(self._lru)
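

if __name__ == "__main__":
    # Minimal smoke test (editor's sketch, not part of the upstream port). The
    # list ["fake-kv-cache"] stands in for a real MLX KV cache; a real backend
    # would also pass can_trim_fn/trim_fn so longer cached sequences can be reused.
    cache = ThreadSafeLRUPromptCache(max_size=2)

    # Cold start: nothing is cached, so every token still needs processing.
    kv, remaining = cache.fetch_nearest_cache("model-a", [1, 2, 3])
    assert kv is None and remaining == [1, 2, 3]
    cache.insert_cache("model-a", [1, 2, 3], ["fake-kv-cache"])

    # Same prompt again: exact hit, nothing left to process. The fetch extracts
    # the entry, so it is re-inserted once "generation" finishes.
    kv, remaining = cache.fetch_nearest_cache("model-a", [1, 2, 3])
    assert kv == ["fake-kv-cache"] and remaining == []
    cache.insert_cache("model-a", [1, 2, 3], kv)

    # Extended prompt: shorter-prefix hit, only the new token remains.
    kv, remaining = cache.fetch_nearest_cache("model-a", [1, 2, 3, 4])
    assert kv == ["fake-kv-cache"] and remaining == [4]

    print("ThreadSafeLRUPromptCache smoke test passed")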