Compare commits

..

2 Commits

Author SHA1 Message Date
Alex Cheema
d91ab9456e feat: add KV prefix caching for prompt reuse
Add LRU-based KV prefix cache to reuse computed prompt prefixes across
requests. When multiple requests share a common prefix (e.g., system prompt),
the cached KV state is reused instead of recomputing it.

Changes:
- Add KVPrefixCache class with LRU eviction in cache.py
- Integrate prefix cache into mlx_generate in generate.py
- Create cache instance in runner.py
- Add comprehensive tests in test_prefix_cache.py
- Update AGENTS.md with full type check command

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 19:02:26 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? lets also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
45 changed files with 1635 additions and 1386 deletions

View File

@@ -30,14 +30,17 @@ uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Type checking (strict mode) - MUST pass before committing
uv run basedpyright --project pyproject.toml
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
# Run all checks (do this before committing)
uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
```
## Architecture
@@ -91,6 +94,10 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## File Locations
- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

View File

@@ -23,7 +23,6 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -1,15 +0,0 @@
short_id = "deepseek-v3.1-4bit"
model_id = "mlx-community/DeepSeek-V3.1-4bit"
name = "DeepSeek V3.1 (4-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-4bit"
pretty_name = "DeepSeek V3.1 (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 405874409472

View File

@@ -1,15 +0,0 @@
short_id = "deepseek-v3.1-8bit"
model_id = "mlx-community/DeepSeek-V3.1-8bit"
name = "DeepSeek V3.1 (8-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-8bit"
pretty_name = "DeepSeek V3.1 (8-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 765577920512

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.5-air-8bit"
model_id = "mlx-community/GLM-4.5-Air-8bit"
name = "GLM 4.5 Air 8bit"
description = "GLM 4.5 Air 8bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-8bit"
pretty_name = "GLM 4.5 Air 8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
[metadata.storage_size]
in_bytes = 122406567936

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.5-air-bf16"
model_id = "mlx-community/GLM-4.5-Air-bf16"
name = "GLM 4.5 Air bf16"
description = "GLM 4.5 Air bf16"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-bf16"
pretty_name = "GLM 4.5 Air bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 229780750336

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-4bit"
model_id = "mlx-community/GLM-4.7-4bit"
name = "GLM 4.7 4bit"
description = "GLM 4.7 4bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-4bit"
pretty_name = "GLM 4.7 4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 198556925568

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-6bit"
model_id = "mlx-community/GLM-4.7-6bit"
name = "GLM 4.7 6bit"
description = "GLM 4.7 6bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-6bit"
pretty_name = "GLM 4.7 6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 286737579648

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-8bit-gs32"
model_id = "mlx-community/GLM-4.7-8bit-gs32"
name = "GLM 4.7 8bit (gs32)"
description = "GLM 4.7 8bit (gs32)"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-8bit-gs32"
pretty_name = "GLM 4.7 8bit (gs32)"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 396963397248

View File

@@ -1,15 +0,0 @@
short_id = "gpt-oss-120b-MXFP4-Q8"
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
description = "OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
pretty_name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
n_layers = 36
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 70652212224

View File

@@ -1,15 +0,0 @@
short_id = "gpt-oss-20b-4bit"
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
description = "OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
pretty_name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
n_layers = 24
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 12025908224

View File

@@ -1,15 +0,0 @@
short_id = "kimi-k2-instruct-4bit"
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
name = "Kimi K2 Instruct (4-bit)"
description = "Kimi K2 is a large language model trained on the Kimi K2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
pretty_name = "Kimi K2 Instruct (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 620622774272

View File

@@ -1,15 +0,0 @@
short_id = "kimi-k2-thinking"
model_id = "mlx-community/Kimi-K2-Thinking"
name = "Kimi K2 Thinking (4-bit)"
description = "Kimi K2 Thinking is the latest, most capable version of open-source thinking model."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Thinking"
pretty_name = "Kimi K2 Thinking (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 706522120192

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-70b"
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
name = "Llama 3.1 70B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
pretty_name = "Llama 3.1 70B (4-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b-8bit"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
name = "Llama 3.1 8B (8-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
pretty_name = "Llama 3.1 8B (8-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 8954839040

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b-bf16"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
name = "Llama 3.1 8B (BF16)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
pretty_name = "Llama 3.1 8B (BF16)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 16882073600

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
name = "Llama 3.1 8B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
pretty_name = "Llama 3.1 8B (4-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 4637851648

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-1b"
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
name = "Llama 3.2 1B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
pretty_name = "Llama 3.2 1B (4-bit)"
n_layers = 16
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 729808896

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-3b-8bit"
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
name = "Llama 3.2 3B (8-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
pretty_name = "Llama 3.2 3B (8-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 3501195264

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-3b"
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
name = "Llama 3.2 3B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
pretty_name = "Llama 3.2 3B (4-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 1863319552

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b-8bit"
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
name = "Llama 3.3 70B (8-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
pretty_name = "Llama 3.3 70B (8-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 76799803392

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b-fp16"
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
name = "Llama 3.3 70B (FP16)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
pretty_name = "Llama 3.3 70B (FP16)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 144383672320

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b"
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
name = "Llama 3.3 70B (4-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
pretty_name = "Llama 3.3 70B"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -1,15 +0,0 @@
short_id = "minimax-m2.1-3bit"
model_id = "mlx-community/MiniMax-M2.1-3bit"
name = "MiniMax M2.1 3bit"
description = "MiniMax M2.1 3bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-3bit"
pretty_name = "MiniMax M2.1 3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 100086644736

View File

@@ -1,15 +0,0 @@
short_id = "minimax-m2.1-8bit"
model_id = "mlx-community/MiniMax-M2.1-8bit"
name = "MiniMax M2.1 8bit"
description = "MiniMax M2.1 8bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-8bit"
pretty_name = "MiniMax M2.1 8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 242986745856

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-0.6b-8bit"
model_id = "mlx-community/Qwen3-0.6B-8bit"
name = "Qwen3 0.6B (8-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-8bit"
pretty_name = "Qwen3 0.6B (8-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 698351616

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-0.6b"
model_id = "mlx-community/Qwen3-0.6B-4bit"
name = "Qwen3 0.6B (4-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-4bit"
pretty_name = "Qwen3 0.6B (4-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 342884352

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-235b-a22b-4bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
name = "Qwen3 235B A22B (4-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
pretty_name = "Qwen3 235B A22B (4-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 141733920768

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-235b-a22b-8bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
name = "Qwen3 235B A22B (8-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
pretty_name = "Qwen3 235B A22B (8-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 268435456000

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-30b-8bit"
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
name = "Qwen3 30B A3B (8-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
pretty_name = "Qwen3 30B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 33279705088

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-30b"
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
name = "Qwen3 30B A3B (4-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
pretty_name = "Qwen3 30B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 17612931072

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
name = "Qwen3 80B A3B (4-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 46976204800

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
name = "Qwen3 80B A3B (8-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-thinking-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
name = "Qwen3 80B A3B Thinking (4-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-thinking-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
name = "Qwen3 80B A3B Thinking (8-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-coder-480b-a35b-4bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
name = "Qwen3 Coder 480B A35B (4-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
pretty_name = "Qwen3 Coder 480B A35B (4-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 289910292480

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-coder-480b-a35b-8bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
name = "Qwen3 Coder 480B A35B (8-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
pretty_name = "Qwen3 Coder 480B A35B (8-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 579820584960

View File

@@ -1,8 +1,5 @@
from anyio import Path, open_file
import tomlkit
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.models.model_meta import get_model_meta
from exo.utils.pydantic_ext import CamelCaseModel
@@ -14,27 +11,542 @@ class ModelCard(CamelCaseModel):
tags: list[str]
metadata: ModelMetadata
@staticmethod
async def load(path: Path) -> "ModelCard":
async with await open_file(path) as f:
data = await f.read()
py = tomlkit.loads(data)
return ModelCard.model_validate(py)
async def save(self, path: Path):
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def from_hf(model_id: str) -> "ModelCard":
short_name = model_id.split("/")[-1]
return ModelCard(
short_id=short_name,
model_id=ModelId(model_id),
name=short_name,
description=f"Custom model from {model_id}",
tags=[],
metadata=await get_model_meta(model_id),
)
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -6,6 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -107,13 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_id,
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -1,104 +1,189 @@
# type: ignore
# TODO: Fix this file, including types!
"""KV prefix cache for reusing computed prompt prefixes across requests."""
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Protocol
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
import numpy as np
from mlx_lm.models.cache import trim_prompt_cache
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
# Type alias for KV cache - the actual type is _BaseCache but it's private
KVCacheType = Any
class TokenizerProtocol(Protocol):
"""Protocol for tokenizers used with KVPrefixCache."""
bos_token: str | None
def encode(self, text: str, **kwargs: bool) -> list[int]: ...
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
"""Cache for common prompt prefixes to avoid re-processing.
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
Uses LRU eviction when capacity is reached. Stores tokenized prompts
and their corresponding KV caches for reuse.
"""
def __init__(self, max_size: int = 10):
"""Initialize prefix cache.
Args:
max_size: Maximum number of cached entries before LRU eviction.
"""
self.max_size = max_size
# OrderedDict maintains insertion order for LRU - most recent at end
# Key: token bytes, Value: (tokens as mx.array, KV cache)
self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
OrderedDict()
)
def _token_key(self, tokens: mx.array) -> bytes:
"""Create hashable key from token array."""
return np.array(tokens.tolist(), dtype=np.int32).tobytes()
def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
"""Tokenize prompt string to mx.array."""
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
return mx.array(tokenized)
def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
"""Find cached entry with longest matching prefix.
Returns:
Tuple of (cache_key, prefix_length). cache_key is None if no match found.
"""
best_key: bytes | None = None
best_length = 0
target_len = tokens.shape[0]
for key, (cached_tokens, _cache) in self._cache.items():
prefix_len = get_prefix_length(tokens, cached_tokens)
# Exact match - return immediately
if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
return key, prefix_len
# Better prefix match
if prefix_len > best_length:
best_key = key
best_length = prefix_len
return best_key, best_length
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
tokenizer: TokenizerProtocol,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
) -> tuple[list[KVCacheType], int]:
"""Get KV cache for prompt, reusing prefix if available.
best_snapshot_index, best_snapshot_length = None, 0
Args:
model: The model to create cache for.
tokenizer: Tokenizer for encoding prompt.
sampler: Sampler function for prefill.
prompt: The prompt string to process.
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
Returns:
Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
tokens were reused from cache (0 if no cache hit).
"""
tokens = self._encode_prompt(tokenizer, prompt)
target_len = int(tokens.shape[0])
if length == max_length:
return self.caches[i]
# Find best prefix match
best_key, prefix_len = self._find_best_prefix(tokens)
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_key is not None and prefix_len > 0:
cached_tokens, cached_kv = self._cache[best_key]
cached_len = int(cached_tokens.shape[0])
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
# Move to end (most recently used)
self._cache.move_to_end(best_key)
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
if prefix_len == target_len and prefix_len == cached_len:
# Exact match - return deepcopy directly
logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
return deepcopy(cached_kv), prefix_len
# Partial match - need to trim and/or extend
prompt_cache = deepcopy(cached_kv)
if cached_len > prefix_len:
# Cached prompt is longer - trim to prefix length
num_to_trim = cached_len - prefix_len
trim_prompt_cache(prompt_cache, num_to_trim)
logger.debug(
f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
)
# Note: We don't prefill remaining tokens here - stream_generate will do it
# when processing the full prompt with this partial cache
return prompt_cache, prefix_len
# No cache hit - return fresh cache (stream_generate will prefill)
logger.debug(
f"Prefix cache: miss, will prefill {target_len} tokens during generation"
)
prompt_cache = make_kv_cache(model=model)
return prompt_cache, 0
def put(
self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
) -> None:
"""Store KV cache for prompt after generation completes.
Args:
tokenizer: Tokenizer for encoding prompt.
prompt: The prompt string that was processed.
cache: The KV cache to store.
"""
tokens = self._encode_prompt(tokenizer, prompt)
key = self._token_key(tokens)
# If already in cache, just move to end
if key in self._cache:
self._cache.move_to_end(key)
return
# Evict LRU entry if at capacity
if len(self._cache) >= self.max_size:
evicted_key, _ = self._cache.popitem(last=False)
logger.debug(
f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
# Store deepcopy
self._cache[key] = (tokens, deepcopy(cache))
logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")
return prompt_cache
def clear(self) -> None:
"""Clear all cached entries."""
self._cache.clear()
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def __len__(self) -> int:
"""Return number of cached entries."""
return len(self._cache)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Calculate length of matching prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass

View File

@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -119,6 +119,7 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -135,8 +136,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -148,6 +147,20 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Get KV cache - either from prefix cache or fresh
tokens_reused = 0
if prefix_cache is not None:
caches, tokens_reused = prefix_cache.get_kv_cache(
model=model,
tokenizer=tokenizer,
sampler=sampler,
prompt=prompt,
)
if tokens_reused > 0:
logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
else:
caches = make_kv_cache(model=model)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -189,6 +202,9 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Store in prefix cache for future reuse
if prefix_cache is not None:
prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -39,6 +39,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -69,6 +70,7 @@ def main(
model = None
tokenizer = None
group = None
prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -110,6 +112,8 @@ def main(
)
model, tokenizer = load_mlx_items(bound_instance, group)
prefix_cache = KVPrefixCache(max_size=10)
logger.info("prefix cache initialized")
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -157,6 +161,7 @@ def main(
model=model,
tokenizer=tokenizer,
task=task_params,
prefix_cache=prefix_cache,
):
match response:
case GenerationResponse():

View File

@@ -0,0 +1,141 @@
"""Tests for KVPrefixCache."""
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.cache import (
KVCacheType,
KVPrefixCache,
TokenizerProtocol,
get_prefix_length,
)
class MockTokenizer(TokenizerProtocol):
"""Mock tokenizer that converts string to list of char codes."""
bos_token: str | None = None
def encode(self, text: str, **kwargs: bool) -> list[int]:
"""Encode text to list of character codes."""
del kwargs # unused
return [ord(c) for c in text]
class TestGetPrefixLength:
"""Tests for the core prefix matching algorithm."""
def test_identical_arrays(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
def test_partial_match(self) -> None:
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 9, 9])
assert get_prefix_length(a, b) == 3
def test_no_match(self) -> None:
a = mx.array([1, 2, 3])
b = mx.array([9, 9, 9])
assert get_prefix_length(a, b) == 0
def test_different_lengths(self) -> None:
short = mx.array([1, 2, 3])
long = mx.array([1, 2, 3, 4, 5])
# Should return length of shorter when they match
assert get_prefix_length(short, long) == 3
assert get_prefix_length(long, short) == 3
def test_empty_array(self) -> None:
empty: mx.array = mx.array([])
tokens = mx.array([1, 2, 3])
assert get_prefix_length(empty, tokens) == 0
assert get_prefix_length(tokens, empty) == 0
class TestKVPrefixCache:
"""Tests for the KV prefix cache."""
@pytest.fixture
def tokenizer(self) -> MockTokenizer:
"""Mock tokenizer that converts string to list of char codes."""
return MockTokenizer()
@pytest.fixture
def fake_kv(self) -> list[KVCacheType]:
"""Fake KV cache for testing."""
return [object()]
def test_put_stores_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_put_same_prompt_twice_does_not_duplicate(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=10)
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "hello", fake_kv)
assert len(cache) == 1
def test_lru_eviction(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Fill cache
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
assert len(cache) == 2
# Add third - should evict "first" (oldest)
cache.put(tokenizer, "third", fake_kv)
assert len(cache) == 2
# Add "first" again - if it was evicted, cache size stays 2
# If it wasn't evicted, this would be a no-op
cache.put(tokenizer, "first", fake_kv)
# Now add fourth - if "first" was re-added, size is still 2
cache.put(tokenizer, "fourth", fake_kv)
assert len(cache) == 2
def test_lru_access_refreshes_entry(
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
) -> None:
cache = KVPrefixCache(max_size=2)
# Add two entries
cache.put(tokenizer, "first", fake_kv)
cache.put(tokenizer, "second", fake_kv)
# Access "first" again (moves to end of LRU)
cache.put(tokenizer, "first", fake_kv)
# Add third - should evict "second" now (oldest)
cache.put(tokenizer, "third", fake_kv)
# Add "second" again - this will add it as new entry
cache.put(tokenizer, "second", fake_kv)
# Now "first" is oldest, adding fourth should evict it
cache.put(tokenizer, "fourth", fake_kv)
# Cache should have "second" and "fourth", not "first"
assert len(cache) == 2
def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
cache = KVPrefixCache()
cache.put(tokenizer, "hello", fake_kv)
cache.put(tokenizer, "world", fake_kv)
assert len(cache) == 2
cache.clear()
assert len(cache) == 0

View File

@@ -1,4 +1,5 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
@@ -6,6 +7,8 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
@@ -15,8 +18,9 @@ async def check_reachability(
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
@@ -32,7 +36,16 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()

1493
uv.lock generated
View File

File diff suppressed because it is too large Load Diff