Compare commits


1 Commit

Author: Alex Cheema
SHA1: d91ab9456e
Date: 2026-01-15 19:02:26 +00:00

feat: add KV prefix caching for prompt reuse

Add LRU-based KV prefix cache to reuse computed prompt prefixes across
requests. When multiple requests share a common prefix (e.g., system prompt),
the cached KV state is reused instead of recomputing it.

Changes:
- Add KVPrefixCache class with LRU eviction in cache.py
- Integrate prefix cache into mlx_generate in generate.py
- Create cache instance in runner.py
- Add comprehensive tests in test_prefix_cache.py
- Update AGENTS.md with full type check command

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
11 changed files with 407 additions and 217 deletions
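
In outline, the new cache keys on tokenized prompts, looks up the longest shared prefix on each request, and evicts the least recently used entry when full. Below is a minimal, self-contained sketch of that idea in plain Python (no MLX); `ToyPrefixCache` and `prefix_length` are hypothetical names that only mirror the shape of the `KVPrefixCache` added in cache.py further down.

```python
from collections import OrderedDict


def prefix_length(a: list[int], b: list[int]) -> int:
    """Length of the common prefix of two token sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class ToyPrefixCache:
    """LRU map from token sequences to (stand-in) KV state."""

    def __init__(self, max_size: int = 10) -> None:
        self.max_size = max_size
        self._entries: OrderedDict[tuple[int, ...], object] = OrderedDict()

    def get(self, tokens: list[int]) -> tuple[object | None, int]:
        """Return (kv_state, matched_prefix_len) for the best cached prefix."""
        best_key: tuple[int, ...] | None = None
        best_len = 0
        for key in self._entries:
            n = prefix_length(tokens, list(key))
            if n > best_len:
                best_key, best_len = key, n
        if best_key is None:
            return None, 0
        self._entries.move_to_end(best_key)  # mark as most recently used
        return self._entries[best_key], best_len

    def put(self, tokens: list[int], kv_state: object) -> None:
        key = tuple(tokens)
        if key in self._entries:
            self._entries.move_to_end(key)
            return
        if len(self._entries) >= self.max_size:
            self._entries.popitem(last=False)  # evict least recently used
        self._entries[key] = kv_state
```

Two requests that share a system prompt hit the same entry, so only the suffix tokens still need to be prefilled.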

View File

@@ -30,14 +30,17 @@ uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Type checking (strict mode) - MUST pass before committing
uv run basedpyright --project pyproject.toml
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
# Run all checks (do this before committing)
uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
```
## Architecture
@@ -91,6 +94,10 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## File Locations
- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
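
For reference, a minimal test under this configuration could look like the sketch below; with `asyncio_mode = "auto"` no `@pytest.mark.asyncio` marker is needed, and the assertion simply checks the documented `EXO_TESTS=1` variable (the test name is hypothetical, and it assumes the test harness sets the variable as described above).

```python
import os


async def test_exo_tests_env_flag() -> None:
    # asyncio_mode = "auto" collects async tests without an explicit marker.
    assert os.environ.get("EXO_TESTS") == "1"
```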

View File

@@ -60,39 +60,12 @@
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -400,8 +400,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -761,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +771,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -61,6 +67,8 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -373,8 +381,35 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -382,10 +417,16 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -401,11 +442,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -417,7 +458,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -425,7 +466,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -454,7 +495,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -462,7 +503,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -503,6 +544,8 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -519,16 +562,17 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -545,7 +589,10 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
def _calculate_total_available_memory(self) -> Memory:
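
The channel handling in `_process_gpt_oss` above boils down to wrapping everything emitted on the Harmony "analysis" channel in `<think>` tags. A harmony-free sketch of that wrapping logic, where the `(channel, delta)` pairs are hypothetical stand-ins for parsed token chunks rather than the library's actual stream:

```python
from collections.abc import Iterable, Iterator


def wrap_analysis(deltas: Iterable[tuple[str, str]]) -> Iterator[str]:
    """Emit <think> tags around content coming from the 'analysis' channel."""
    thinking = False
    for channel, delta in deltas:
        if channel == "analysis" and not thinking:
            thinking = True
            yield "<think>"
        if channel != "analysis" and thinking:
            thinking = False
            yield "</think>"
        if delta:
            yield delta
    if thinking:  # close the tag if the stream ends mid-analysis
        yield "</think>"


# "".join(wrap_analysis([("analysis", "plan"), ("final", "answer")]))
# -> "<think>plan</think>answer"
```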

View File

@@ -425,15 +425,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,

View File

@@ -1,104 +1,189 @@
# type: ignore
# TODO: Fix this file, including types!
"""KV prefix cache for reusing computed prompt prefixes across requests."""
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Protocol
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
import numpy as np
from mlx_lm.models.cache import trim_prompt_cache
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
# Type alias for KV cache - the actual type is _BaseCache but it's private
KVCacheType = Any
class TokenizerProtocol(Protocol):
"""Protocol for tokenizers used with KVPrefixCache."""
bos_token: str | None
def encode(self, text: str, **kwargs: bool) -> list[int]: ...
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
"""Cache for common prompt prefixes to avoid re-processing.
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
Uses LRU eviction when capacity is reached. Stores tokenized prompts
and their corresponding KV caches for reuse.
"""
def __init__(self, max_size: int = 10):
"""Initialize prefix cache.
Args:
max_size: Maximum number of cached entries before LRU eviction.
"""
self.max_size = max_size
# OrderedDict maintains insertion order for LRU - most recent at end
# Key: token bytes, Value: (tokens as mx.array, KV cache)
self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
OrderedDict()
)
def _token_key(self, tokens: mx.array) -> bytes:
"""Create hashable key from token array."""
return np.array(tokens.tolist(), dtype=np.int32).tobytes()
def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
"""Tokenize prompt string to mx.array."""
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
return mx.array(tokenized)
def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
"""Find cached entry with longest matching prefix.
Returns:
Tuple of (cache_key, prefix_length). cache_key is None if no match found.
"""
best_key: bytes | None = None
best_length = 0
target_len = tokens.shape[0]
for key, (cached_tokens, _cache) in self._cache.items():
prefix_len = get_prefix_length(tokens, cached_tokens)
# Exact match - return immediately
if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
return key, prefix_len
# Better prefix match
if prefix_len > best_length:
best_key = key
best_length = prefix_len
return best_key, best_length
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
tokenizer: TokenizerProtocol,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
) -> tuple[list[KVCacheType], int]:
"""Get KV cache for prompt, reusing prefix if available.
best_snapshot_index, best_snapshot_length = None, 0
Args:
model: The model to create cache for.
tokenizer: Tokenizer for encoding prompt.
sampler: Sampler function for prefill.
prompt: The prompt string to process.
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
Returns:
Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
tokens were reused from cache (0 if no cache hit).
"""
tokens = self._encode_prompt(tokenizer, prompt)
target_len = int(tokens.shape[0])
if length == max_length:
return self.caches[i]
# Find best prefix match
best_key, prefix_len = self._find_best_prefix(tokens)
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_key is not None and prefix_len > 0:
cached_tokens, cached_kv = self._cache[best_key]
cached_len = int(cached_tokens.shape[0])
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
# Move to end (most recently used)
self._cache.move_to_end(best_key)
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
if prefix_len == target_len and prefix_len == cached_len:
# Exact match - return deepcopy directly
logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
return deepcopy(cached_kv), prefix_len
# Partial match - need to trim and/or extend
prompt_cache = deepcopy(cached_kv)
if cached_len > prefix_len:
# Cached prompt is longer - trim to prefix length
num_to_trim = cached_len - prefix_len
trim_prompt_cache(prompt_cache, num_to_trim)
logger.debug(
f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
)
# Note: We don't prefill remaining tokens here - stream_generate will do it
# when processing the full prompt with this partial cache
return prompt_cache, prefix_len
# No cache hit - return fresh cache (stream_generate will prefill)
logger.debug(
f"Prefix cache: miss, will prefill {target_len} tokens during generation"
)
prompt_cache = make_kv_cache(model=model)
return prompt_cache, 0
def put(
self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
) -> None:
"""Store KV cache for prompt after generation completes.
Args:
tokenizer: Tokenizer for encoding prompt.
prompt: The prompt string that was processed.
cache: The KV cache to store.
"""
tokens = self._encode_prompt(tokenizer, prompt)
key = self._token_key(tokens)
# If already in cache, just move to end
if key in self._cache:
self._cache.move_to_end(key)
return
# Evict LRU entry if at capacity
if len(self._cache) >= self.max_size:
evicted_key, _ = self._cache.popitem(last=False)
logger.debug(
f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
# Store deepcopy
self._cache[key] = (tokens, deepcopy(cache))
logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")
return prompt_cache
def clear(self) -> None:
"""Clear all cached entries."""
self._cache.clear()
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def __len__(self) -> int:
"""Return number of cached entries."""
return len(self._cache)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Calculate length of matching prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
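
The `mx.cumprod` trick in `get_prefix_length` can be sanity-checked with a small worked example (values chosen for illustration):

```python
import mlx.core as mx

a = mx.array([7, 8, 9, 1, 2])
b = mx.array([7, 8, 9, 5, 2])

equal = mx.equal(a, b).astype(mx.int32)  # [1, 1, 1, 0, 1]
prefix_mask = mx.cumprod(equal)          # [1, 1, 1, 0, 0] - zero after the first mismatch
print(int(mx.sum(prefix_mask).item()))   # 3 matching prefix tokens
```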

View File

@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -119,6 +119,7 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -135,8 +136,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -148,6 +147,20 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Get KV cache - either from prefix cache or fresh
tokens_reused = 0
if prefix_cache is not None:
caches, tokens_reused = prefix_cache.get_kv_cache(
model=model,
tokenizer=tokenizer,
sampler=sampler,
prompt=prompt,
)
if tokens_reused > 0:
logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
else:
caches = make_kv_cache(model=model)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -189,6 +202,9 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Store in prefix cache for future reuse
if prefix_cache is not None:
prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -20,7 +20,6 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -366,8 +365,6 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -399,11 +396,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -1,15 +1,6 @@
import time
from collections.abc import Generator
from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -48,6 +39,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -78,6 +70,7 @@ def main(
model = None
tokenizer = None
group = None
prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -119,6 +112,8 @@ def main(
)
model, tokenizer = load_mlx_items(bound_instance, group)
prefix_cache = KVPrefixCache(max_size=10)
logger.info("prefix cache initialized")
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -162,19 +157,12 @@ def main(
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
prefix_cache=prefix_cache,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
@@ -224,43 +212,6 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -0,0 +1,141 @@
"""Tests for KVPrefixCache."""
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.cache import (
KVCacheType,
KVPrefixCache,
TokenizerProtocol,
get_prefix_length,
)
class MockTokenizer(TokenizerProtocol):
"""Mock tokenizer that converts string to list of char codes."""
bos_token: str | None = None
def encode(self, text: str, **kwargs: bool) -> list[int]:
"""Encode text to list of character codes."""
del kwargs # unused
return [ord(c) for c in text]
class TestGetPrefixLength:
"""Tests for the core prefix matching algorithm."""
def test_identical_arrays(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
def test_partial_match(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 9, 9])
assert get_prefix_length(a, b) == 3
def test_no_match(self) -> None:
a = mx.array([1, 2, 3])
b = mx.array([9, 9, 9])
assert get_prefix_length(a, b) == 0
def test_different_lengths(self) -> None:
short = mx.array([1, 2, 3])
long = mx.array([1, 2, 3, 4, 5])
# Should return length of shorter when they match
assert get_prefix_length(short, long) == 3
assert get_prefix_length(long, short) == 3
def test_empty_array(self) -> None:
empty: mx.array = mx.array([])
tokens = mx.array([1, 2, 3])
assert get_prefix_length(empty, tokens) == 0
assert get_prefix_length(tokens, empty) == 0
class TestKVPrefixCache:
"""Tests for the KV prefix cache."""
@pytest.fixture
def tokenizer(self) -> MockTokenizer:
"""Mock tokenizer that converts string to list of char codes."""
return MockTokenizer()
@pytest.fixture
def fake_kv(self) -> list[KVCacheType]:
"""Fake KV cache for testing."""
return [object()]
def test_put_stores_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_put_same_prompt_twice_does_not_duplicate(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_lru_eviction(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Fill cache
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
assert len(cache) == 2
# Add third - should evict "first" (oldest)
cache.put(tokenizer, "third", fake_kv)
assert len(cache) == 2
# Add "first" again - if it was evicted, cache size stays 2
# If it wasn't evicted, this would be a no-op
cache.put(tokenizer, "first", fake_kv)
# Now add fourth - if "first" was re-added, size is still 2
cache.put(tokenizer, "fourth", fake_kv)
assert len(cache) == 2
def test_lru_access_refreshes_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Add two entries
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
# Access "first" again (moves to end of LRU)
cache.put(tokenizer, "first", fake_kv)
# Add third - should evict "second" now (oldest)
cache.put(tokenizer, "third", fake_kv)
# Add "second" again - this will add it as new entry
cache.put(tokenizer, "second", fake_kv)
# Now "first" is oldest, adding fourth should evict it
cache.put(tokenizer, "fourth", fake_kv)
# Cache should have "second" and "fourth", not "first"
assert len(cache) == 2
def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
cache = KVPrefixCache()
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "world", fake_kv)
assert len(cache) == 2
cache.clear()
assert len(cache) == 0