Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-05 03:33:30 -05:00)

Compare commits

15 Commits
rust-explo ... leo/add-mo
| SHA1 |
|---|
| c7f65fa9a3 |
| 82c36f1b26 |
| db8919a89a |
| fe05608260 |
| bce8ee3a6e |
| 741e2790dd |
| 1f29b9c85d |
| bfa3160339 |
| d745157342 |
| e3751e03c2 |
| cd9f3182d9 |
| a54ba12dee |
| d98f2c9b68 |
| 33f22ca78a |
| b60a59bbf6 |
@@ -1139,7 +1139,7 @@ class array:
) -> array:
"""See :func:`flatten`."""

- def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
+ def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`reshape` but the shape can be passed either as a
:obj:`tuple` or as separate arguments.

@@ -1222,7 +1222,7 @@ class array:
) -> array:
"""See :func:`swapaxes`."""

- def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
+ def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`transpose` but the axes can be passed either as
a tuple or as separate arguments.

@@ -30,6 +30,9 @@ class Conv1d(Module):
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""

weight: mx.array
groups: int
def __init__(
self,
in_channels: int,

@@ -11,7 +11,10 @@ import mlx.core as mx
class Cache(Protocol):
keys: mx.array
values: mx.array
- def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
+ offset: int
+ def update_and_fetch(
+ self, keys: mx.array, values: mx.array
+ ) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter

@@ -87,6 +90,7 @@ def create_attention_mask(
class _BaseCache(Cache):
keys: mx.array
values: mx.array
+ offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter

.mlx_typings/mlx_lm/models/qwen3_next.pyi (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Type stubs for mlx_lm.models.qwen3_next"""

from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .switch_layers import SwitchGLU

class Qwen3NextMLP(nn.Module):
gate_proj: nn.Linear
down_proj: nn.Linear
up_proj: nn.Linear

def __init__(self, dim: int, hidden_dim: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

class Qwen3NextGatedDeltaNet(nn.Module):
hidden_size: int
num_v_heads: int
num_k_heads: int
head_k_dim: int
head_v_dim: int
key_dim: int
value_dim: int
conv_kernel_size: int
conv_dim: int
conv1d: nn.Conv1d
in_proj_qkvz: nn.Linear
in_proj_ba: nn.Linear
dt_bias: mx.array
A_log: mx.array
out_proj: nn.Linear

def __init__(self, config: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...

class Qwen3NextAttention(nn.Module):
num_attention_heads: int
num_key_value_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear

def __init__(self, args: Any) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...

class Qwen3NextSparseMoeBlock(nn.Module):
norm_topk_prob: bool
num_experts: int
top_k: int
gate: nn.Linear
switch_mlp: SwitchGLU
shared_expert: Qwen3NextMLP
shared_expert_gate: nn.Linear

def __init__(self, args: Any) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

class Qwen3NextDecoderLayer(nn.Module):
is_linear: bool
linear_attn: Qwen3NextGatedDeltaNet
self_attn: Qwen3NextAttention
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock

def __init__(self, args: Any, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...

class Qwen3NextModel(nn.Module):
embed_tokens: nn.Embedding
layers: list[Qwen3NextDecoderLayer]
norm: nn.RMSNorm

def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...

class Model(nn.Module):
model_type: str
model: Qwen3NextModel
lm_head: nn.Linear

def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[Qwen3NextDecoderLayer]: ...

@@ -112,6 +112,10 @@ class TokenizerWrapper:
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
+ think_start: str | None
+ think_end: str | None
+ think_start_id: int | None
+ think_end_id: int | None

def __init__(
self,

@@ -3,10 +3,11 @@
from collections.abc import Sequence

from mlx_lm.models.cache import (
+ ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)

# This list contains one cache entry per transformer layer
- KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
+ KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
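For context on the widened union above: a minimal, self-contained sketch (plain-Python stand-in classes, not the real mlx_lm cache types) of why code that consumes a per-layer sequence like KVCacheType can no longer assume every entry carries an `offset`:

```python
# Hypothetical stand-ins for the mlx_lm cache classes, used only to
# illustrate dispatching over a per-layer union like KVCacheType.
from typing import Sequence, Union


class KVCache:
    def __init__(self) -> None:
        self.offset = 0  # number of cached tokens


class RotatingKVCache(KVCache):
    pass


class QuantizedKVCache(KVCache):
    pass


class ArraysCache:
    """SSM-style cache: holds recurrent state, no token offset to trim."""

    def __init__(self, num_arrays: int = 2) -> None:
        self.state = [None] * num_arrays


LayerCache = Union[KVCache, RotatingKVCache, QuantizedKVCache, ArraysCache]


def cached_tokens(cache: Sequence[LayerCache]) -> int:
    # ArraysCache has no offset, so fall back to 0 (mirrors getattr(c, "offset", 0)
    # in the cache_length change further down this diff).
    return max((getattr(c, "offset", 0) for c in cache), default=0)


if __name__ == "__main__":
    layers: list[LayerCache] = [KVCache(), ArraysCache(), RotatingKVCache()]
    layers[0].offset = 7
    print(cached_tokens(layers))  # 7
```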
@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
+ from mlx_lm.models.base import (
+ scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]
+ )
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP

@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import MiniMaxAttention
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
- from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
+ from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata

if TYPE_CHECKING:
from mlx_lm.models.cache import Cache

TimeoutCallback = Callable[[], None]

@@ -618,6 +626,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
return y


class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
self.group = group

def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: "Cache | None" = None,
) -> mx.array:
batch_dim, seq_dim, _ = x.shape

self._original_layer = cast(MiniMaxAttention, self.original_layer)  # type: ignore

queries: mx.array = self._original_layer.q_proj(x)
keys: mx.array = self._original_layer.k_proj(x)
values: mx.array = self._original_layer.v_proj(x)

if getattr(self, "use_qk_norm", False):
q_dim = queries.shape[-1]
k_dim = keys.shape[-1]
n = self.group.size()

qk = mx.concatenate(
[queries, keys], axis=-1
)  # (batch_dim, seq_dim, q_dim + k_dim)
qk = mx.distributed.all_gather(
qk, group=self.group
)  # (n*batch_dim, seq_dim, q_dim + k_dim)

qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
queries = qk[..., :q_dim].reshape(
batch_dim, seq_dim, -1
)  # (batch_dim, seq_dim, n * q_dim)
keys = qk[..., q_dim:].reshape(
batch_dim, seq_dim, -1
)  # (batch_dim, seq_dim, n * k_dim)

queries = self._original_layer.q_norm(queries)
keys = self._original_layer.k_norm(keys)

# Split back and take this rank's portion
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
keys = mx.split(keys, n, axis=-1)[self.group.rank()]

queries = queries.reshape(
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
).transpose(0, 2, 1, 3)
keys = keys.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)
values = values.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)

if cache is not None:
queries = self._original_layer.rope(queries, offset=cache.offset)
keys = self._original_layer.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self._original_layer.rope(queries)
keys = self._original_layer.rope(keys)

output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self._original_layer.scale,  # type: ignore
mask=mask,
)

output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)

return self._original_layer.o_proj(output)


class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,

@@ -626,7 +712,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
- rank = self.group.rank()
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout

@@ -637,18 +722,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)

- # Shard qk_norm weights if present (must match sharded head count)
- if getattr(layer.self_attn, "use_qk_norm", False):
- layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split(  # type: ignore
- self.N, axis=-1
- )[rank]
- layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split(  # type: ignore
- self.N, axis=-1
- )[rank]

layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N

+ layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group)  # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]

# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(

@@ -673,18 +751,95 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
- model = cast(Qwen3MoeModel, model)
+ model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
if isinstance(layer, Qwen3DecoderLayer):
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
else:
assert isinstance(layer, Qwen3NextDecoderLayer)
if hasattr(layer, "linear_attn"):
linear_attn = layer.linear_attn

linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
linear_attn.in_proj_qkvz
)
linear_attn.in_proj_ba = self.all_to_sharded_linear(
linear_attn.in_proj_ba
)
linear_attn.out_proj = self.sharded_to_all_linear(
linear_attn.out_proj
)

# Shard conv1d: depthwise conv with non-contiguous channel slicing.
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
# Each rank takes its head-slice from each of the three sections.
rank = self.group.rank()
key_dim = linear_attn.key_dim
value_dim = linear_attn.value_dim
key_dim_shard = key_dim // self.N
value_dim_shard = value_dim // self.N

q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
k_idx = mx.arange(
key_dim + rank * key_dim_shard,
key_dim + (rank + 1) * key_dim_shard,
)
v_idx = mx.arange(
2 * key_dim + rank * value_dim_shard,
2 * key_dim + (rank + 1) * value_dim_shard,
)
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
new_conv_dim = key_dim_shard * 2 + value_dim_shard
linear_attn.conv1d.groups = new_conv_dim

num_v_shard = linear_attn.num_v_heads // self.N
v_start = rank * num_v_shard
v_end = v_start + num_v_shard
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]

linear_attn.num_k_heads //= self.N
linear_attn.num_v_heads //= self.N
linear_attn.key_dim = (
linear_attn.head_k_dim * linear_attn.num_k_heads
)
linear_attn.value_dim = (
linear_attn.head_v_dim * linear_attn.num_v_heads
)
linear_attn.conv_dim = (
linear_attn.key_dim * 2 + linear_attn.value_dim
)
else:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N

# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
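The comments in the hunk above describe the [q(key_dim), k(key_dim), v(value_dim)] channel layout that the conv1d sharding relies on. A small hedged sketch of that per-rank index arithmetic, with toy dimensions invented for illustration and plain Python ranges standing in for mx.arange:

```python
# Toy illustration of the per-rank channel slicing used when sharding the
# depthwise conv1d of a gated-delta-net layer. Dimensions are made up.


def conv_channel_indices(key_dim: int, value_dim: int, n_ranks: int, rank: int) -> list[int]:
    key_shard = key_dim // n_ranks
    value_shard = value_dim // n_ranks

    # Channel layout of the unsharded conv: [q(key_dim), k(key_dim), v(value_dim)]
    q_idx = list(range(rank * key_shard, (rank + 1) * key_shard))
    k_idx = list(range(key_dim + rank * key_shard, key_dim + (rank + 1) * key_shard))
    v_idx = list(
        range(2 * key_dim + rank * value_shard, 2 * key_dim + (rank + 1) * value_shard)
    )
    return q_idx + k_idx + v_idx


if __name__ == "__main__":
    key_dim, value_dim, n = 8, 16, 2
    all_indices = [conv_channel_indices(key_dim, value_dim, n, r) for r in range(n)]

    # Every rank gets key_dim/n + key_dim/n + value_dim/n channels ...
    assert all(len(ix) == 2 * (key_dim // n) + value_dim // n for ix in all_indices)
    # ... and together the ranks cover every channel exactly once.
    assert sorted(i for ix in all_indices for i in ix) == list(range(2 * key_dim + value_dim))
    print(all_indices[0])
```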
@@ -694,6 +849,14 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
self.all_to_sharded_linear_in_place(
layer.mlp.shared_expert.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_expert.down_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group

@@ -1,16 +1,14 @@
import os
from copy import deepcopy
from typing import Any, cast

import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.types.memory import Memory

@@ -26,12 +24,79 @@ _MEMORY_THRESHOLD = float(
)


class CacheLayerSnapshot:
"""Snapshot of a single cache layer's state."""


class ArraysCacheSnapshot(CacheLayerSnapshot):
def __init__(self, state: list[object]):
self.state = state


class RotatingKVCacheSnapshot(CacheLayerSnapshot):
def __init__(
self, keys: mx.array | None, values: mx.array | None, offset: int, idx: int
):
self.keys = keys
self.values = values
self.offset = offset
self.idx = idx


class CacheSnapshot:
def __init__(self, states: list[CacheLayerSnapshot | None], token_count: int):
self.states = states
self.token_count = token_count


def snapshot_cache_states(cache: KVCacheType) -> CacheSnapshot:
states: list[CacheLayerSnapshot | None] = []
for c in cache:
if isinstance(c, ArraysCache):
states.append(ArraysCacheSnapshot(deepcopy(c.state)))  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
elif isinstance(c, RotatingKVCache):
# Deep copy arrays to avoid mutation during later generation
# (keys/values can be None for empty cache despite type annotation)
keys_copy = mx.array(c.keys) if c.keys is not None else None  # pyright: ignore[reportUnnecessaryComparison]
values_copy = mx.array(c.values) if c.values is not None else None  # pyright: ignore[reportUnnecessaryComparison]
states.append(
RotatingKVCacheSnapshot(
keys=keys_copy,
values=values_copy,
offset=c.offset,
idx=c._idx,  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownArgumentType]
)
)
else:
states.append(None)
token_count = cache_length(cache)
return CacheSnapshot(states=states, token_count=token_count)


def _find_nearest_snapshot(
snapshots: list[CacheSnapshot],
target_token_count: int,
) -> CacheSnapshot | None:
best: CacheSnapshot | None = None
for snap in snapshots:
if snap.token_count <= target_token_count and (
best is None or snap.token_count > best.token_count
):
best = snap
return best


def has_abnormal_kv_caches(cache: KVCacheType) -> bool:
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)


class KVPrefixCache:
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
self.prompts: list[mx.array] = []  # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
self._last_used: list[int] = []  # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
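A short sketch of the selection rule used by `_find_nearest_snapshot` above: pick the latest snapshot at or before the target token count. The `Snap` dataclass below is a toy stand-in for the real CacheSnapshot:

```python
# Toy illustration of picking the latest snapshot at or before a target
# token count, mirroring the selection rule in _find_nearest_snapshot.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Snap:
    token_count: int


def nearest_snapshot(snapshots: list[Snap], target: int) -> Optional[Snap]:
    best: Optional[Snap] = None
    for snap in snapshots:
        if snap.token_count <= target and (best is None or snap.token_count > best.token_count):
            best = snap
    return best


if __name__ == "__main__":
    snaps = [Snap(0), Snap(128), Snap(512), Snap(900)]
    assert nearest_snapshot(snaps, 600).token_count == 512
    assert nearest_snapshot(snaps, 100).token_count == 0
    assert nearest_snapshot([], 100) is None
    print("ok")
```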
@@ -41,36 +106,67 @@ class KVPrefixCache:
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._snapshots.clear()
self._last_used.clear()

- def add_kv_cache(self, prompt: str, cache: KVCacheType):
+ def add_kv_cache(
+ self,
+ prompt_tokens: mx.array,
+ cache: KVCacheType,
+ snapshots: list[CacheSnapshot] | None = None,
+ ):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
- tokenized_prompt = encode_prompt(self._tokenizer, prompt)
- self.prompts.append(tokenized_prompt)
+ self.prompts.append(prompt_tokens)
self.caches.append(deepcopy(cache))
self._snapshots.append(snapshots)
self._access_counter += 1
self._last_used.append(self._access_counter)
- logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
+ logger.info(f"KV cache added: {len(prompt_tokens)} tokens")

def update_kv_cache(
self,
index: int,
- prompt: str,
+ prompt_tokens: mx.array,
cache: KVCacheType,
snapshots: list[CacheSnapshot] | None,
restore_pos: int,
):
"""Update an existing cache entry in-place."""
- tokenized_prompt = encode_prompt(self._tokenizer, prompt)
- self.prompts[index] = tokenized_prompt
old_snapshots = self._snapshots[index]
merged: list[CacheSnapshot] = []
if old_snapshots:
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
if snapshots:
merged.extend(snapshots)

+ self.prompts[index] = prompt_tokens
self.caches[index] = deepcopy(cache)
self._snapshots[index] = merged or None
self._access_counter += 1
self._last_used[index] = self._access_counter
- logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
+ logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")

def _get_snapshot(
self, entry_index: int, target_token_count: int
) -> tuple[int, CacheSnapshot | None]:
if not has_abnormal_kv_caches(self.caches[entry_index]):
return target_token_count, None

snapshots = self._snapshots[entry_index]
if not snapshots:
return 0, None

snap = _find_nearest_snapshot(snapshots, target_token_count)
if snap is not None:
return snap.token_count, snap

return 0, None

def get_kv_cache(
self,
model: Model,
- prompt: str,
+ prompt_tokens: mx.array,
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.

@@ -79,76 +175,94 @@ class KVPrefixCache:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)

For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
nearest SSM snapshot position at or before the match point for correctness.
Same for RotatingKVCache.
"""
- tokenized_prompt = encode_prompt(self._tokenizer, prompt)
- max_length = len(tokenized_prompt)
+ max_length = len(prompt_tokens)

- best_snapshot_index, best_snapshot_length = None, 0
+ best_index: int | None = None
+ best_length = 0
+ is_exact = False

# Find best cache
for i, cached_prompt in enumerate(self.prompts):
- length = get_prefix_length(tokenized_prompt, cached_prompt)
+ length = get_prefix_length(prompt_tokens, cached_prompt)
if length > best_length:
best_index, best_length = i, length
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
is_exact = True
best_index, best_length = i, length
break

if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length

if best_snapshot_index is not None:
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)

prompt_cache = deepcopy(self.caches[best_snapshot_index])

# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)

self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index

else:
prompt_cache = make_kv_cache(model)
if best_index is None:
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return make_kv_cache(model), prompt_tokens, None

return prompt_cache, tokenized_prompt, None
# For exact match we trim to max_length-1
target = (max_length - 1) if is_exact else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)

# Snapshotable model with no usable snapshot — need fresh cache
if (
restore_pos == 0
and restore_snap is None
and has_abnormal_kv_caches(self.caches[best_index])
):
match_kind = (
"exact match"
if is_exact
else f"prefix match at {best_length}/{max_length}"
)
logger.info(
f"KV cache {match_kind} but no snapshot, "
f"need to prefill {max_length} tokens"
)
return make_kv_cache(model), prompt_tokens, None

prompt_cache = deepcopy(self.caches[best_index])
cached_length = cache_length(self.caches[best_index])
tokens_to_trim = cached_length - restore_pos
if tokens_to_trim > 0:
trim_cache(prompt_cache, tokens_to_trim, restore_snap)

self._access_counter += 1
self._last_used[best_index] = self._access_counter
remaining = prompt_tokens[restore_pos:]

if is_exact:
logger.info(
f"KV cache exact match: {max_length} tokens "
f"(reusing {restore_pos}, re-processing {len(remaining)})"
)
else:
logger.info(
f"KV cache prefix match: {best_length}/{max_length} tokens "
f"(restoring to {restore_pos}, need to prefill {len(remaining)})"
)
return prompt_cache, remaining, best_index

def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return

- # Evict LRU entries until below threshold or only one entry left
+ # Evict LRU entries until below threshold
while (
- len(self.caches) > 1
+ len(self.caches) > 0
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._snapshots.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
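The exact-match and prefix-match branches above share one piece of arithmetic: how many cached tokens to trim and which prompt tokens still need prefilling. A simplified sketch with plain Python lists, ignoring the snapshot restore path (no mlx arrays here):

```python
# Plain-list sketch of the prefix-match / trim bookkeeping used by
# KVPrefixCache.get_kv_cache. Token ids are arbitrary ints.


def prefix_length(prompt: list[int], cached_prompt: list[int]) -> int:
    n = 0
    for a, b in zip(prompt, cached_prompt):
        if a != b:
            break
        n += 1
    return n


def plan_reuse(prompt: list[int], cached_prompt: list[int], cached_length: int):
    """Return (tokens_to_trim, remaining_tokens) for a prefix hit."""
    hit = prefix_length(prompt, cached_prompt)
    # Exact match: keep max_length - 1 so the last token can seed generation.
    restore_pos = len(prompt) - 1 if hit == len(prompt) else hit
    tokens_to_trim = max(cached_length - restore_pos, 0)
    remaining = prompt[restore_pos:]
    return tokens_to_trim, remaining


if __name__ == "__main__":
    cached = [1, 2, 3, 4, 5, 6]
    # Prefix hit: four shared tokens, the diverging tail still needs prefilling.
    print(plan_reuse([1, 2, 3, 4, 9, 9], cached, cached_length=6))  # (2, [9, 9])
    # Exact hit: trim back to len - 1 and re-feed only the final token.
    print(plan_reuse([1, 2, 3, 4, 5, 6], cached, cached_length=6))  # (1, [6])
```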
@@ -169,6 +283,40 @@ class KVPrefixCache:
return max_pressure


def trim_cache(
cache: KVCacheType,
num_tokens: int,
snapshot: CacheSnapshot | None = None,
) -> None:
for i, c in enumerate(cache):
layer_snap = snapshot.states[i] if snapshot is not None else None

if isinstance(c, ArraysCache):
if isinstance(layer_snap, ArraysCacheSnapshot):
c.state = deepcopy(layer_snap.state)
else:
c.state = [None] * len(c.state)  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
elif isinstance(c, RotatingKVCache):
if isinstance(layer_snap, RotatingKVCacheSnapshot):
c.keys = (
mx.array(layer_snap.keys) if layer_snap.keys is not None else None
)  # pyright: ignore[reportAttributeAccessIssue]
c.values = (
mx.array(layer_snap.values)
if layer_snap.values is not None
else None
)  # pyright: ignore[reportAttributeAccessIssue]
c.offset = layer_snap.offset
c._idx = layer_snap.idx  # pyright: ignore[reportAttributeAccessIssue]
else:
c.keys = None  # pyright: ignore[reportAttributeAccessIssue]
c.values = None  # pyright: ignore[reportAttributeAccessIssue]
c.offset = 0
c._idx = 0  # pyright: ignore[reportAttributeAccessIssue]
else:
c.trim(num_tokens)  # pyright: ignore[reportUnknownMemberType]


def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.

@@ -177,14 +325,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
- tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
- return mx.array(tokenized_prompt)
+ prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+ return mx.array(prompt_tokens)


def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
- # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
- return max(c.offset for c in cache)  # type: ignore
+ # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
+ return max(getattr(c, "offset", 0) for c in cache)


def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:

@@ -215,7 +363,7 @@ def make_kv_cache(
assert hasattr(model, "layers")

# TODO: Do this for all models
- if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
+ if hasattr(model, "make_cache"):
logger.info("Using MLX LM's make cache")
return model.make_cache()  # type: ignore

@@ -1,9 +1,10 @@
import time
- from typing import Any, Callable, Generator, cast, get_args
+ from copy import deepcopy
+ from typing import Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
- from mlx_lm.models.cache import trim_prompt_cache
+ from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

@@ -22,10 +23,20 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
- from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
+ from exo.worker.engines.mlx.cache import (
+ ArraysCacheSnapshot,
+ CacheSnapshot,
+ KVPrefixCache,
+ RotatingKVCacheSnapshot,
+ encode_prompt,
+ has_abnormal_kv_caches,
+ make_kv_cache,
+ snapshot_cache_states,
+ )
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
fix_unmatched_think_end_tokens,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger

@@ -41,7 +52,8 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
- ) -> tuple[float, int]:
+ capture_snapshots: bool = False,
+ ) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.

This runs the model over the prompt tokens to populate the cache,

@@ -52,17 +64,21 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
- return 0.0, 0
+ return 0.0, 0, []

logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
+ has_snapshotable = has_abnormal_kv_caches(cache)
+ snapshots: list[CacheSnapshot] = []

def progress_callback(processed: int, total: int) -> None:
- elapsed = time.time() - start_time
+ elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
+ if has_snapshotable:
+ snapshots.append(snapshot_cache_states(cache))

# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache

@@ -79,7 +95,30 @@ def prefill(
prompt_progress_callback=progress_callback,
):
break  # Stop after first iteration - cache is now filled
- trim_prompt_cache(cast(list[Any], cache), 1)
+
+ # stream_generate added 1 extra generated token to the cache, so we should trim it.
+ pre_gen = snapshots[-1] if has_snapshotable else None
+
+ for i, c in enumerate(cache):
+ layer_snap = pre_gen.states[i] if pre_gen is not None else None
+
+ if isinstance(c, ArraysCache):
+ if isinstance(layer_snap, ArraysCacheSnapshot):
+ c.state = deepcopy(layer_snap.state)
+ elif isinstance(c, RotatingKVCache):
+ if isinstance(layer_snap, RotatingKVCacheSnapshot):
+ c.keys = (
+ mx.array(layer_snap.keys) if layer_snap.keys is not None else None
+ )  # pyright: ignore[reportAttributeAccessIssue]
+ c.values = (
+ mx.array(layer_snap.values)
+ if layer_snap.values is not None
+ else None
+ )  # pyright: ignore[reportAttributeAccessIssue]
+ c.offset = layer_snap.offset
+ c._idx = layer_snap.idx  # pyright: ignore[reportAttributeAccessIssue]
+ else:
+ c.trim(1)  # pyright: ignore[reportUnknownMemberType]

elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
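The block above restores the pre-generation snapshot for ArraysCache/RotatingKVCache layers instead of calling trim(1), because recurrent state cannot be rewound once the throwaway token has been processed. A toy illustration of that capture-then-restore pattern (the MiniState class is invented for this sketch):

```python
# Toy capture/restore pattern for state that cannot be trimmed, mirroring
# why prefill() snapshots SSM-style caches before the throwaway token.


class MiniState:
    """Invented stand-in for an SSM-style cache layer: updates overwrite state."""

    def __init__(self) -> None:
        self.h = 0  # recurrent state, folded over every token seen

    def step(self, token: int) -> None:
        # The previous value of h is overwritten, so there is no trim()/rewind.
        self.h = (self.h * 31 + token) % (2**32)


if __name__ == "__main__":
    cache = MiniState()
    for tok in [10, 11, 12]:       # prefill the prompt tokens
        cache.step(tok)

    snapshot = cache.h             # capture before the throwaway token
    cache.step(99)                 # the single token stream_generate produces
    assert cache.h != snapshot

    cache.h = snapshot             # restore instead of trimming one position
    print("restored:", cache.h)
```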
@@ -87,12 +126,13 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
- return tokens_per_sec, num_tokens
+ return tokens_per_sec, num_tokens, snapshots


def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
+ group: mx.distributed.Group | None = None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."

@@ -130,9 +170,7 @@ def warmup_inference(

logger.info("Generated ALL warmup tokens")

- # TODO: Do we want an mx_barrier?
- # At least this version is actively incorrect, as it should use mx_barrier(group)
- mx_barrier()
+ mx_barrier(group)

return tokens_generated

@@ -161,12 +199,17 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
+ group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
if task.seed is not None:
mx.random.seed(task.seed)

+ # Encode prompt once at the top and fix unmatched think tags
+ all_prompt_tokens = encode_prompt(tokenizer, prompt)
+ all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
+
# Do not use the prefix cache if we are trying to do benchmarks.
is_bench = task.bench
if is_bench:

@@ -177,12 +220,11 @@ def mlx_generate(
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
- prompt_tokens = encode_prompt(tokenizer, prompt)
+ prompt_tokens = all_prompt_tokens
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
- model, prompt
+ model, all_prompt_tokens
)
- all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)

logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []

@@ -206,9 +248,16 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)

# Prefill cache with all tokens except the last one
- prefill_tps, prefill_tokens = prefill(
- model, tokenizer, sampler, prompt_tokens[:-1], caches
+ capture_snapshots = has_abnormal_kv_caches(caches) and kv_prefix_cache is not None
+ prefill_tps, prefill_tokens, cache_snapshots_list = prefill(
+ model,
+ tokenizer,
+ sampler,
+ prompt_tokens[:-1],
+ caches,
+ capture_snapshots=capture_snapshots,
)
+ cache_snapshots: list[CacheSnapshot] | None = cache_snapshots_list or None

# stream_generate starts from the last token
last_token = prompt_tokens[-1:]

@@ -296,16 +345,11 @@ def mlx_generate(
),
)

- yield GenerationResponse(
- text=text,
- token=out.token,
- finish_reason=finish_reason,
- stats=stats,
- usage=usage,
- )
-
if is_done:
- # Log generation stats
+ # Update prefix cache BEFORE yielding the final response.
+ # Consumers typically break on finish_reason, which prevents
+ # the generator from resuming — so any code after yield
+ # would never execute.
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (

@@ -317,18 +361,42 @@ def mlx_generate(
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
- full_prompt = prompt + "".join(generated_text_parts)
+ generated_tokens_array = mx.array(
+ tokenizer.encode(
+ "".join(generated_text_parts), add_special_tokens=False
+ )
+ )
+ full_prompt_tokens = mx.concatenate(
+ [all_prompt_tokens, generated_tokens_array]
+ )
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
- kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
+ kv_prefix_cache.update_kv_cache(
+ matched_index,
+ full_prompt_tokens,
+ caches,
+ cache_snapshots,
+ restore_pos=prefix_hit_length,
+ )
else:
- kv_prefix_cache.add_kv_cache(full_prompt, caches)
+ kv_prefix_cache.add_kv_cache(
+ full_prompt_tokens, caches, cache_snapshots
+ )

+ yield GenerationResponse(
+ text=text,
+ token=out.token,
+ finish_reason=finish_reason,
+ stats=stats,
+ usage=usage,
+ )
+
if is_done:
+ mx_barrier(group)
break

# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]

# TODO: Do we want an mx_barrier?

@@ -464,6 +464,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
return think_token is not None and prompt.rstrip().endswith(think_token)


def fix_unmatched_think_end_tokens(
tokens: mx.array, tokenizer: TokenizerWrapper
) -> mx.array:
if not tokenizer.has_thinking:
return tokens
assert tokenizer.think_start_id
assert tokenizer.think_end_id
think_start_id: int = tokenizer.think_start_id
think_end_id: int = tokenizer.think_end_id
token_list: list[int] = cast(list[int], tokens.tolist())
result: list[int] = []
depth = 0
for token in token_list:
if token == think_start_id:
depth += 1
elif token == think_end_id:
if depth == 0:
result.append(think_start_id)
else:
depth -= 1
result.append(token)
return mx.array(result)


class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.
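A standalone sketch of the repair rule implemented by fix_unmatched_think_end_tokens above: an end tag seen at depth zero gets a start tag inserted in front of it, otherwise the depth counter is decremented. Token ids here are arbitrary placeholders, not real tokenizer ids:

```python
# Standalone sketch of the unmatched think-end repair on plain int tokens.
THINK_START_ID = 1001  # arbitrary ids for illustration
THINK_END_ID = 1002


def fix_unmatched_think_end(tokens: list[int]) -> list[int]:
    result: list[int] = []
    depth = 0
    for token in tokens:
        if token == THINK_START_ID:
            depth += 1
        elif token == THINK_END_ID:
            if depth == 0:
                # Unmatched end tag: insert a start tag in front of it.
                result.append(THINK_START_ID)
            else:
                depth -= 1
        result.append(token)
    return result


if __name__ == "__main__":
    # A stray end tag gets a start tag inserted before it.
    assert fix_unmatched_think_end([42, THINK_END_ID, 43]) == [
        42,
        THINK_START_ID,
        THINK_END_ID,
        43,
    ]
    # Properly matched tags pass through untouched.
    assert fix_unmatched_think_end([THINK_START_ID, 7, THINK_END_ID]) == [
        THINK_START_ID,
        7,
        THINK_END_ID,
    ]
    print("ok")
```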
@@ -204,6 +204,7 @@ def main(
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
+ group=group,
# kv_prefix_cache=kv_prefix_cache,  # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")

@@ -250,6 +251,7 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
+ group=group,
)

# For other thinking models (GLM, etc.), check if we need to

@@ -603,7 +605,7 @@ def parse_thinking_models(
yield response.model_copy(
update={
"text": tokenizer.think_start,
- "token": tokenizer.think_start_id,  # type: ignore
+ "token": tokenizer.think_start_id,
}
)
yield response

@@ -162,7 +162,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

kv_prefix_cache = KVPrefixCache(tokenizer)
- kv_prefix_cache.add_kv_cache(prompt, cache)
+ kv_prefix_cache.add_kv_cache(tokens, cache)

assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])

@@ -170,7 +170,7 @@ class TestKVPrefixCacheWithModel:

# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
- model, prompt
+ model, tokens
)
assert matched_index == 0

@@ -194,7 +194,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)

kv_prefix_cache = KVPrefixCache(tokenizer)
- kv_prefix_cache.add_kv_cache(short_prompt, cache)
+ kv_prefix_cache.add_kv_cache(short_tokens, cache)

# Query with longer prompt that shares the chat template prefix
long_task = TextGenerationTaskParams(

@@ -212,7 +212,7 @@ class TestKVPrefixCacheWithModel:
)

result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
- model, long_prompt
+ model, long_tokens
)
assert matched_index == 0

@@ -238,12 +238,12 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

kv_prefix_cache = KVPrefixCache(tokenizer)
- kv_prefix_cache.add_kv_cache(prompt, cache)
+ kv_prefix_cache.add_kv_cache(tokens, cache)

stored_length = cache_length(kv_prefix_cache.caches[0])

# Get cache and mutate it (simulating what generation does)
- result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
+ result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
assert matched_index == 0

# Simulate generation: feed many additional tokens through the cache

@@ -276,12 +276,12 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

kv_prefix_cache = KVPrefixCache(tokenizer)
- kv_prefix_cache.add_kv_cache(prompt, cache)
+ kv_prefix_cache.add_kv_cache(tokens, cache)

stored_length = cache_length(kv_prefix_cache.caches[0])

for i in range(3):
- result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
+ result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)

head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]

@@ -352,7 +352,7 @@ class TestKVPrefixCacheWithModel:
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
- model, prompt
+ model, prompt_tokens
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained

@@ -495,7 +495,7 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
- kv_prefix_cache.add_kv_cache(prompt, cache)
+ kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)

@@ -505,19 +505,10 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next

- # Simulate memory pressure: active memory exceeds threshold
- fake_limit = 1000
- fake_active = int(fake_limit * 0.90)  # Above _MEMORY_THRESHOLD (0.85)
-
- with (
- patch(
- "exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
- return_value=fake_active,
- ),
- patch(
- "exo.worker.engines.mlx.cache.mx.metal.device_info",
- return_value={"max_recommended_working_set_size": fake_limit},
- ),
+ # Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
+ with patch(
+ "exo.worker.engines.mlx.cache.get_memory_used_percentage",
+ return_value=0.95,
):
# Trigger eviction by adding a new entry
task = TextGenerationTaskParams(

@@ -529,14 +520,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
- kv_prefix_cache.add_kv_cache(prompt, cache)
+ kv_prefix_cache.add_kv_cache(tokens, cache)

# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since fake_active stays above threshold after each eviction (we don't change it),
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
- new_tokens = encode_prompt(tokenizer, prompt)
- assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
- new_tokens
- )
+ assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)
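The eviction test above exercises the counter-based LRU policy from KVPrefixCache._evict_if_needed. A compact, self-contained sketch of that policy with the memory probe replaced by an injected callable (an assumption of this sketch; the real code reads device memory):

```python
# Isolated sketch of the counter-based LRU eviction used by KVPrefixCache.
# The memory probe is injected as a callable instead of reading GPU memory.
from typing import Callable


class LruEntries:
    def __init__(self, memory_used: Callable[[], float], threshold: float = 0.9) -> None:
        self.entries: list[str] = []
        self._last_used: list[int] = []
        self._counter = 0
        self._memory_used = memory_used
        self._threshold = threshold

    def add(self, name: str) -> None:
        self._evict_if_needed()
        self.entries.append(name)
        self._counter += 1
        self._last_used.append(self._counter)

    def touch(self, index: int) -> None:
        self._counter += 1
        self._last_used[index] = self._counter

    def _evict_if_needed(self) -> None:
        # Keep evicting the least-recently-used entry while memory stays high.
        while self.entries and self._memory_used() > self._threshold:
            lru = self._last_used.index(min(self._last_used))
            self.entries.pop(lru)
            self._last_used.pop(lru)


if __name__ == "__main__":
    pressure = {"value": 0.5}
    cache = LruEntries(lambda: pressure["value"])
    for name in ["a", "b", "c"]:
        cache.add(name)
    cache.touch(0)                      # "a" becomes most recently used

    pressure["value"] = 0.95            # simulate memory pressure
    cache.add("d")                      # evicts all older entries, then adds "d"
    assert cache.entries == ["d"]
    print(cache.entries)
```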