Mirror of https://github.com/mudler/LocalAI.git
feat(mlx-distributed): add new MLX-distributed backend (#8801)

* feat(mlx-distributed): add new MLX-distributed backend

  Add a new MLX distributed backend with support for both TCP and RDMA for model sharding.
  This implementation ties in the discovery implementation already in place, and re-uses
  the same P2P mechanism for TCP MLX-distributed inferencing. The auto-parallel
  implementation is inspired by Exo's (Exo has been added to the acknowledgements for the
  great work!)

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* expose a CLI to facilitate backend starting

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: make manual rank0 configurable via model configs

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add missing features from mlx backend

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Apply suggestion from @mudler

  Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
committed by GitHub
parent 734b6d391f
commit a026277ab9
backend/python/mlx-distributed/mlx_cache.py (new file, 266 lines)
@@ -0,0 +1,266 @@
"""
Thread-safe LRU prompt cache for MLX-based backends.

Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
with thread-safety additions for LocalAI's gRPC backend.

Usage:
    from mlx_cache import ThreadSafeLRUPromptCache

    # In LoadModel:
    self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)

    # In Predict/PredictStream:
    prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
    # ... generate ...
    self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
"""
import copy
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class CacheEntry:
    """A cache entry with reference counting."""
    prompt_cache: List[Any]
    count: int


@dataclass
class SearchResult:
    """Result of searching the cache trie."""
    model: Any
    exact: Optional[List[int]]
    shorter: Optional[List[int]]
    longer: Optional[List[int]]
    common_prefix: int


class ThreadSafeLRUPromptCache:
    """
    Thread-safe LRU cache with prefix matching for prompt KV caches.

    This cache stores KV caches keyed by token sequences and supports:
    - Exact match: Return the cache for the exact token sequence
    - Shorter prefix match: Return a cache for a prefix of the tokens
    - Longer prefix match: If a longer sequence is cached and can be trimmed
    - LRU eviction: When max_size is exceeded, evict least recently used

    Thread safety is provided via a threading.Lock that protects all
    cache operations.

    Args:
        max_size: Maximum number of cache entries (default: 10)
        can_trim_fn: Optional function to check if a cache can be trimmed
        trim_fn: Optional function to trim a cache
    """
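    # Worked example (illustrative sketch, not part of the upstream docstring):
    # after insert_cache("model", [1, 2, 3], kv), a later
    # fetch_nearest_cache("model", [1, 2, 3, 4]) matches the shorter prefix
    # (1, 2, 3) and returns (kv, [4]), so only token 4 still needs prefill.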

    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        self.max_size = max_size
        self._cache = {}
        self._lru = deque()
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn
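        # Illustrative wiring (an assumption, not enforced by this module):
        # callers can pass mlx_lm's prompt-cache helpers here, e.g.
        #   from mlx_lm.models.cache import can_trim_prompt_cache, trim_prompt_cache
        #   ThreadSafeLRUPromptCache(can_trim_fn=can_trim_prompt_cache, trim_fn=trim_prompt_cache)
        # so that longer cached sequences can be trimmed and reused.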

    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return SearchResult(model, tuple(tokens), None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            if best is not None:
                longer = tuple(tokens[:index] + best)

        return SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie."""
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]

    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes."""
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                break
            del d_prev[t]

    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.
        """
        cache_entry = self._get(model, tokens)
        if cache_entry.count == 1:
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )

    def fetch_nearest_cache(
        self, model, tokens: List[int]
    ) -> Tuple[Optional[List[Any]], List[int]]:
        """
        Fetch the nearest cache for the given token sequence.

        Thread-safe. Returns (cache, remaining_tokens) where:
        - cache: The KV cache to use (or None if no cache found)
        - remaining_tokens: Tokens that still need to be processed

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence for the prompt

        Returns:
            Tuple of (prompt_cache, remaining_tokens)
        """
        with self._lock:
            tokens_tuple = tuple(tokens)
            result = self._search(model, tokens)

            # Exact match - extract and return
            if result.exact is not None:
                cache_entry = self._extract(result.model, result.exact)
                return cache_entry.prompt_cache, []

            # Shorter prefix match - extract and return remaining
            if result.shorter is not None:
                cache_entry = self._extract(result.model, result.shorter)
                prefix_len = len(result.shorter)
                return cache_entry.prompt_cache, list(tokens[prefix_len:])

            # Longer prefix match - try to trim if possible
            if result.longer is not None and self._can_trim_fn is not None:
                cache_entry = self._get(result.model, result.longer)
                if self._can_trim_fn(cache_entry.prompt_cache):
                    # Deep copy and trim
                    trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                    prefix = min(len(tokens) - 1, result.common_prefix)
                    num_to_trim = len(result.longer) - prefix
                    if self._trim_fn is not None:
                        self._trim_fn(trimmed_cache, num_to_trim)
                    return trimmed_cache, list(tokens[prefix:])

            # No match found
            return None, list(tokens)

    def insert_cache(
        self, model, tokens: List[int], prompt_cache: List[Any]
    ) -> None:
        """
        Insert a cache entry after generation completes.

        Thread-safe. Handles LRU eviction if max_size is exceeded.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence (prompt + generated)
            prompt_cache: The KV cache to store
        """
        with self._lock:
            tokens_tuple = tuple(tokens)

            if model not in self._cache:
                self._cache[model] = {}
            current = self._cache[model]

            # Build trie path
            for tok in tokens_tuple:
                if tok not in current:
                    current[tok] = {}
                current = current[tok]

            # Update or create entry
            if "cache" in current:
                current["cache"].count += 1
                self._lru.remove((model, tokens_tuple))
            else:
                current["cache"] = CacheEntry(prompt_cache, 1)

            # Update LRU order
            self._lru.append((model, tokens_tuple))

            # Evict if over capacity
            if len(self._lru) > self.max_size:
                evict_model, evict_tokens = self._lru.popleft()
                self._delete(evict_model, evict_tokens)

    def clear(self) -> None:
        """Clear all cache entries. Thread-safe."""
        with self._lock:
            self._cache.clear()
            self._lru.clear()

    def __len__(self) -> int:
        """Return the number of cache entries. Thread-safe."""
        with self._lock:
            return len(self._lru)