mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 23:06:23 -05:00
Compare commits
30 Commits
fix-instan
...
leo/add-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a4cf61f9f | ||
|
|
756128e7e1 | ||
|
|
596c0dae60 | ||
|
|
00e35f6c88 | ||
|
|
492029c86a | ||
|
|
8c397b7341 | ||
|
|
4db51dc574 | ||
|
|
9be2743726 | ||
|
|
4d6510d6e2 | ||
|
|
f509af1b4c | ||
|
|
6e80f2187c | ||
|
|
719a41a8cf | ||
|
|
d33ca2f204 | ||
|
|
6c9fd8c71a | ||
|
|
d781b5024a | ||
|
|
1130b8aecf | ||
|
|
82c36f1b26 | ||
|
|
db8919a89a | ||
|
|
fe05608260 | ||
|
|
bce8ee3a6e | ||
|
|
741e2790dd | ||
|
|
1f29b9c85d | ||
|
|
bfa3160339 | ||
|
|
d745157342 | ||
|
|
e3751e03c2 | ||
|
|
cd9f3182d9 | ||
|
|
a54ba12dee | ||
|
|
d98f2c9b68 | ||
|
|
33f22ca78a | ||
|
|
b60a59bbf6 |
@@ -1139,7 +1139,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`flatten`."""
|
||||
|
||||
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
|
||||
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`reshape` but the shape can be passed either as a
|
||||
:obj:`tuple` or as separate arguments.
|
||||
@@ -1222,7 +1222,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`swapaxes`."""
|
||||
|
||||
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
|
||||
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`transpose` but the axes can be passed either as
|
||||
a tuple or as separate arguments.
|
||||
|
||||
@@ -30,6 +30,9 @@ class Conv1d(Module):
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
|
||||
weight: mx.array
|
||||
groups: int
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
||||
@@ -11,7 +11,10 @@ import mlx.core as mx
|
||||
class Cache(Protocol):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
|
||||
offset: int
|
||||
def update_and_fetch(
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
@@ -87,6 +90,7 @@ def create_attention_mask(
|
||||
class _BaseCache(Cache):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.mla import MultiLinear
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
# kv_b_proj: nn.Linear
|
||||
embed_q: MultiLinear
|
||||
unembed_out: MultiLinear
|
||||
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
|
||||
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Type stubs for mlx_lm.models.qwen3_next"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
class Qwen3NextMLP(nn.Module):
|
||||
gate_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextGatedDeltaNet(nn.Module):
|
||||
hidden_size: int
|
||||
num_v_heads: int
|
||||
num_k_heads: int
|
||||
head_k_dim: int
|
||||
head_v_dim: int
|
||||
key_dim: int
|
||||
value_dim: int
|
||||
conv_kernel_size: int
|
||||
conv_dim: int
|
||||
conv1d: nn.Conv1d
|
||||
in_proj_qkvz: nn.Linear
|
||||
in_proj_ba: nn.Linear
|
||||
dt_bias: mx.array
|
||||
A_log: mx.array
|
||||
out_proj: nn.Linear
|
||||
|
||||
def __init__(self, config: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
norm_topk_prob: bool
|
||||
num_experts: int
|
||||
top_k: int
|
||||
gate: nn.Linear
|
||||
switch_mlp: SwitchGLU
|
||||
shared_expert: Qwen3NextMLP
|
||||
shared_expert_gate: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextDecoderLayer(nn.Module):
|
||||
is_linear: bool
|
||||
linear_attn: Qwen3NextGatedDeltaNet
|
||||
self_attn: Qwen3NextAttention
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
|
||||
|
||||
def __init__(self, args: Any, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextModel(nn.Module):
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Qwen3NextDecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: Qwen3NextModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[Qwen3NextDecoderLayer]: ...
|
||||
@@ -113,6 +113,10 @@ class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
think_start: str | None
|
||||
think_end: str | None
|
||||
think_start_id: int | None
|
||||
think_end_id: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
version = let v = "0.30.5"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"mlx==0.30.5; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.5; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -63,7 +63,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
|
||||
@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import MiniMaxAttention
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
@@ -503,12 +511,21 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
@@ -624,6 +641,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: "Cache | None" = None,
|
||||
) -> mx.array:
|
||||
batch_dim, seq_dim, _ = x.shape
|
||||
|
||||
self._original_layer = cast(MiniMaxAttention, self.original_layer) # type: ignore
|
||||
|
||||
queries: mx.array = self._original_layer.q_proj(x)
|
||||
keys: mx.array = self._original_layer.k_proj(x)
|
||||
values: mx.array = self._original_layer.v_proj(x)
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1]
|
||||
k_dim = keys.shape[-1]
|
||||
n = self.group.size()
|
||||
|
||||
qk = mx.concatenate(
|
||||
[queries, keys], axis=-1
|
||||
) # (batch_dim, seq_dim, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (n*batch_dim, seq_dim, q_dim + k_dim)
|
||||
|
||||
qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * q_dim)
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * k_dim)
|
||||
|
||||
queries = self._original_layer.q_norm(queries)
|
||||
keys = self._original_layer.k_norm(keys)
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, n, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self._original_layer.rope(queries, offset=cache.offset)
|
||||
keys = self._original_layer.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self._original_layer.rope(queries)
|
||||
keys = self._original_layer.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self._original_layer.scale, # type: ignore
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
|
||||
|
||||
return self._original_layer.o_proj(output)
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -632,7 +727,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -643,18 +737,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -679,18 +766,95 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
if isinstance(layer, Qwen3DecoderLayer):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
linear_attn = layer.linear_attn
|
||||
|
||||
linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_qkvz
|
||||
)
|
||||
linear_attn.in_proj_ba = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_ba
|
||||
)
|
||||
linear_attn.out_proj = self.sharded_to_all_linear(
|
||||
linear_attn.out_proj
|
||||
)
|
||||
|
||||
# Shard conv1d: depthwise conv with non-contiguous channel slicing.
|
||||
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
|
||||
# Each rank takes its head-slice from each of the three sections.
|
||||
rank = self.group.rank()
|
||||
key_dim = linear_attn.key_dim
|
||||
value_dim = linear_attn.value_dim
|
||||
key_dim_shard = key_dim // self.N
|
||||
value_dim_shard = value_dim // self.N
|
||||
|
||||
q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
|
||||
k_idx = mx.arange(
|
||||
key_dim + rank * key_dim_shard,
|
||||
key_dim + (rank + 1) * key_dim_shard,
|
||||
)
|
||||
v_idx = mx.arange(
|
||||
2 * key_dim + rank * value_dim_shard,
|
||||
2 * key_dim + (rank + 1) * value_dim_shard,
|
||||
)
|
||||
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
|
||||
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
|
||||
new_conv_dim = key_dim_shard * 2 + value_dim_shard
|
||||
linear_attn.conv1d.groups = new_conv_dim
|
||||
|
||||
num_v_shard = linear_attn.num_v_heads // self.N
|
||||
v_start = rank * num_v_shard
|
||||
v_end = v_start + num_v_shard
|
||||
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
|
||||
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]
|
||||
|
||||
linear_attn.num_k_heads //= self.N
|
||||
linear_attn.num_v_heads //= self.N
|
||||
linear_attn.key_dim = (
|
||||
linear_attn.head_k_dim * linear_attn.num_k_heads
|
||||
)
|
||||
linear_attn.value_dim = (
|
||||
linear_attn.head_v_dim * linear_attn.num_v_heads
|
||||
)
|
||||
linear_attn.conv_dim = (
|
||||
linear_attn.key_dim * 2 + linear_attn.value_dim
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
@@ -700,6 +864,14 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_expert.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_expert.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -26,51 +24,119 @@ _MEMORY_THRESHOLD = float(
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
class CacheSnapshot:
|
||||
"""Snapshot of states at a known token position."""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
|
||||
):
|
||||
self.states = states
|
||||
self.token_count = token_count
|
||||
|
||||
|
||||
def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
|
||||
states: list[ArraysCache | RotatingKVCache | None] = []
|
||||
for c in cache:
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
states.append(deepcopy(c))
|
||||
else:
|
||||
states.append(None)
|
||||
token_count = cache_length(cache)
|
||||
return CacheSnapshot(states=states, token_count=token_count)
|
||||
|
||||
|
||||
def _find_nearest_snapshot(
|
||||
snapshots: list[CacheSnapshot],
|
||||
target_token_count: int,
|
||||
) -> CacheSnapshot | None:
|
||||
best: CacheSnapshot | None = None
|
||||
for snap in snapshots:
|
||||
if snap.token_count <= target_token_count and (
|
||||
best is None or snap.token_count > best.token_count
|
||||
):
|
||||
best = snap
|
||||
return best
|
||||
|
||||
|
||||
def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
"""Check if a cache contains any ArraysCache (SSM) entries."""
|
||||
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
ssm_snapshots: list[CacheSnapshot] | None = None,
|
||||
):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.prompts.append(prompt_tokens)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._snapshots.append(ssm_snapshots)
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt: str,
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
snapshots: list[CacheSnapshot] | None,
|
||||
restore_pos: int,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
old_snapshots = self._snapshots[index]
|
||||
merged: list[CacheSnapshot] = []
|
||||
if old_snapshots:
|
||||
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
|
||||
if snapshots:
|
||||
merged.extend(snapshots)
|
||||
|
||||
self.prompts[index] = prompt_tokens
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._snapshots[index] = merged or None
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
|
||||
|
||||
def _get_snapshot(
|
||||
self, entry_index: int, target_token_count: int
|
||||
) -> tuple[int, CacheSnapshot | None]:
|
||||
if not has_non_kv_caches(self.caches[entry_index]):
|
||||
return target_token_count, None
|
||||
|
||||
snapshots = self._snapshots[entry_index]
|
||||
if not snapshots:
|
||||
return 0, None
|
||||
|
||||
snap = _find_nearest_snapshot(snapshots, target_token_count)
|
||||
if snap is not None:
|
||||
return snap.token_count, snap
|
||||
|
||||
return 0, None
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
prompt: str,
|
||||
prompt_tokens: mx.array,
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
@@ -79,76 +145,71 @@ class KVPrefixCache:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
|
||||
For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
|
||||
nearest SSM snapshot position at or before the match point for correctness.
|
||||
Same for rotating KV Cache.
|
||||
"""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
max_length = len(prompt_tokens)
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
best_index: int | None = None
|
||||
best_length = 0
|
||||
is_exact = False
|
||||
|
||||
# Find best cache
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
length = get_prefix_length(prompt_tokens, cached_prompt)
|
||||
if length > best_length:
|
||||
best_index, best_length = i, length
|
||||
if length == max_length:
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
is_exact = True
|
||||
best_index, best_length = i, length
|
||||
break
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
if best_index is None:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
# For exact match: trim to max_length-1 so remaining has the last token
|
||||
# For partial match: trim to best_length, remaining has suffix to prefill
|
||||
# This ensures stream_generate always has at least one token to start with
|
||||
target = (max_length - 1) if is_exact else best_length
|
||||
restore_pos, restore_snap = self._get_snapshot(best_index, target)
|
||||
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
# No usable snapshot — need fresh cache
|
||||
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
prompt_cache = deepcopy(self.caches[best_index])
|
||||
cached_length = cache_length(self.caches[best_index])
|
||||
tokens_to_trim = cached_length - restore_pos
|
||||
if tokens_to_trim > 0:
|
||||
trim_cache(prompt_cache, tokens_to_trim, restore_snap)
|
||||
# Reset cache offset to match trimmed length
|
||||
for c in prompt_cache:
|
||||
if hasattr(c, "offset"):
|
||||
c.offset = restore_pos
|
||||
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
self._access_counter += 1
|
||||
self._last_used[best_index] = self._access_counter
|
||||
remaining = prompt_tokens[restore_pos:]
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
return prompt_cache, remaining, best_index
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
# Evict LRU entries until below threshold
|
||||
while (
|
||||
len(self.caches) > 1
|
||||
len(self.caches) > 0
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._snapshots.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
@@ -169,6 +230,21 @@ class KVPrefixCache:
|
||||
return max_pressure
|
||||
|
||||
|
||||
def trim_cache(
|
||||
cache: KVCacheType,
|
||||
num_tokens: int,
|
||||
snapshot: CacheSnapshot | None = None,
|
||||
) -> None:
|
||||
for i, c in enumerate(cache):
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
if snapshot is not None and snapshot.states[i] is not None:
|
||||
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
|
||||
else:
|
||||
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
else:
|
||||
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
@@ -177,14 +253,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(prompt_tokens)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
return max(getattr(c, "offset", 0) for c in cache)
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
@@ -214,8 +290,7 @@ def make_kv_cache(
|
||||
) -> KVCacheType:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
if hasattr(model, "make_cache"):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import time
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
@@ -23,7 +24,14 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
encode_prompt,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
DEFAULT_TOP_LOGPROBS,
|
||||
KV_BITS,
|
||||
@@ -32,6 +40,7 @@ from exo.worker.engines.mlx.constants import (
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -47,7 +56,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> tuple[float, int]:
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
@@ -58,17 +67,21 @@ def prefill(
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0, 0
|
||||
return 0.0, 0, []
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
has_ssm = has_non_kv_caches(cache)
|
||||
snapshots: list[CacheSnapshot] = []
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
if has_ssm:
|
||||
snapshots.append(snapshot_ssm_states(cache))
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
@@ -85,7 +98,18 @@ def prefill(
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
# Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
|
||||
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
|
||||
for i, c in enumerate(cache):
|
||||
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
assert pre_gen is not None
|
||||
if pre_gen.states[i] is not None:
|
||||
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
@@ -93,12 +117,14 @@ def prefill(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec, num_tokens
|
||||
# Exclude the last snapshot
|
||||
return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -117,7 +143,7 @@ def warmup_inference(
|
||||
)
|
||||
|
||||
# Use a default sampler for warmup
|
||||
sampler = make_sampler(temp=0.7)
|
||||
sampler = make_sampler(temp=0.0)
|
||||
|
||||
logger.info("Generating warmup tokens")
|
||||
for _r in stream_generate(
|
||||
@@ -136,9 +162,7 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated
|
||||
|
||||
@@ -221,11 +245,17 @@ def mlx_generate(
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
|
||||
seed = task.seed or 42
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Encode prompt once at the top and fix unmatched think tags
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
|
||||
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
is_bench = task.bench
|
||||
@@ -237,13 +267,16 @@ def mlx_generate(
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prompt_tokens = all_prompt_tokens
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, all_prompt_tokens
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
if prefix_hit_length > 0:
|
||||
logger.info(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -266,12 +299,17 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
last_token = prompt_tokens[-2:]
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
@@ -299,7 +337,6 @@ def mlx_generate(
|
||||
start=1,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
@@ -367,16 +404,6 @@ def mlx_generate(
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
@@ -390,18 +417,44 @@ def mlx_generate(
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
generated_tokens_array = mx.array(
|
||||
tokenizer.encode(
|
||||
"".join(generated_text_parts), add_special_tokens=False
|
||||
)
|
||||
)
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
full_prompt_tokens,
|
||||
caches,
|
||||
cache_snapshots,
|
||||
restore_pos=prefix_hit_length,
|
||||
)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
kv_prefix_cache.add_kv_cache(
|
||||
full_prompt_tokens, caches, cache_snapshots
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
mx_barrier(group)
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -490,6 +490,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
|
||||
return think_token is not None and prompt.rstrip().endswith(think_token)
|
||||
|
||||
|
||||
def fix_unmatched_think_end_tokens(
|
||||
tokens: mx.array, tokenizer: TokenizerWrapper
|
||||
) -> mx.array:
|
||||
if not tokenizer.has_thinking:
|
||||
return tokens
|
||||
assert tokenizer.think_start_id
|
||||
assert tokenizer.think_end_id
|
||||
think_start_id: int = tokenizer.think_start_id
|
||||
think_end_id: int = tokenizer.think_end_id
|
||||
token_list: list[int] = cast(list[int], tokens.tolist())
|
||||
result: list[int] = []
|
||||
depth = 0
|
||||
for token in token_list:
|
||||
if token == think_start_id:
|
||||
depth += 1
|
||||
elif token == think_end_id:
|
||||
if depth == 0:
|
||||
result.append(think_start_id)
|
||||
else:
|
||||
depth -= 1
|
||||
result.append(token)
|
||||
return mx.array(result)
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
"""
|
||||
A KVCache that pretends to exist but holds zero tokens.
|
||||
|
||||
@@ -193,7 +193,7 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
@@ -226,6 +226,7 @@ def main(
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
@@ -274,6 +275,7 @@ def main(
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
@@ -627,7 +629,7 @@ def parse_thinking_models(
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id, # type: ignore
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,10 +142,12 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
# Snapshots should be available for models with non-KV caches
|
||||
assert len(snapshots) > 0
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -159,10 +161,10 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -170,7 +172,7 @@ class TestKVPrefixCacheWithModel:
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, tokens
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
@@ -191,10 +193,12 @@ class TestKVPrefixCacheWithModel:
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = TextGenerationTaskParams(
|
||||
@@ -212,13 +216,12 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_prompt
|
||||
model, long_tokens
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
# remaining_tokens covers from snapshot restore position to end
|
||||
assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
@@ -235,15 +238,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
@@ -273,15 +276,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
@@ -298,7 +301,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -328,7 +331,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -352,20 +355,20 @@ class TestKVPrefixCacheWithModel:
|
||||
# Second call should find a prefix match (the stored cache contains
|
||||
# prompt + generated tokens, which shares the prompt prefix)
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, prompt_tokens
|
||||
)
|
||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
||||
# so this is a prefix match where our prompt is fully contained
|
||||
assert matched_index == 0
|
||||
# Exact match: remaining_tokens is just the last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
||||
# Exact match: remaining_tokens is just the last token and the one before
|
||||
assert len(remaining_tokens) == 2
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])
|
||||
|
||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -444,7 +447,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -481,7 +484,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -495,7 +498,7 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
|
||||
@@ -505,19 +508,10 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache._last_used[2] = 100.0
|
||||
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
|
||||
|
||||
# Simulate memory pressure: active memory exceeds threshold
|
||||
fake_limit = 1000
|
||||
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
|
||||
return_value=fake_active,
|
||||
),
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.device_info",
|
||||
return_value={"max_recommended_working_set_size": fake_limit},
|
||||
),
|
||||
# Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
|
||||
with patch(
|
||||
"exo.worker.engines.mlx.cache.get_memory_used_percentage",
|
||||
return_value=0.95,
|
||||
):
|
||||
# Trigger eviction by adding a new entry
|
||||
task = TextGenerationTaskParams(
|
||||
@@ -529,14 +523,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
# Since fake_active stays above threshold after each eviction (we don't change it),
|
||||
# all old entries get evicted, leaving only the newly added one
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)
|
||||
|
||||
@@ -34,6 +34,7 @@ TOKENIZER_FILE_PATTERNS = [
|
||||
"added_tokens.json",
|
||||
"tokenizer.model",
|
||||
"tokenization_*.py", # Custom tokenizer implementations
|
||||
"tool_declaration_ts.py", # Dependency of tokenization_kimi.py
|
||||
]
|
||||
|
||||
|
||||
|
||||
59
uv.lock
generated
59
uv.lock
generated
@@ -413,9 +413,9 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
|
||||
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.5" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.5" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
@@ -1020,22 +1020,22 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.4"
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/59/b6d138f5598bcd13d8e1d029a207cb8b18b14d5ded43533aef16d2e3852b/mlx-0.30.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e4b1ff6584ddcadcadbd7236f3ec6fe30abd918bcd75e51dd7693c113ab7d5f6", size = 572585 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/57/72604531d02471c54dd1c71caeb77479297f37ab6aaa1125b457edfce9ee/mlx-0.30.4-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:1f367534078b10dcb660393a554f97732c194977ac8318bb389a76a6307757f8", size = 572587 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/5c/1a340ccc5051d222ceb58aa00c42ea5d11f4ae0bd0fc97673bef5d6ff24b/mlx-0.30.4-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:5344d195ac60dcdb871afb3ebb15c22112408f54c91ef507bd16e3928dfff38d", size = 572571 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/18/538c13fa6821459d8d2b6db1ac96f60679ef995f373c68be1d743055ba47/mlx-0.30.4-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:6879c7262c8f8f7a1a9ee6f27cbf5fe174d0863189a7672c9eb71cd8611bbaa7", size = 621260 },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/2c/e8aa0847ec97436443a78e87cc3fb95c94a2fe8b4b6ebb65cbaa67b6306c/mlx-0.30.4-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:367ba287ceb5b93a624b560ce8ce02378c03d1d60cc630b57efaf38061596d9b", size = 662522 },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/ab/d0a6303bf0f978e394036841089d58d2c8c305e3efbcce9e4351724b6f5c/mlx-0.30.4-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f026f3a30013e16034419caef0b0293ba84e69252fc1676d5d8becc92bb5a304", size = 574119 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/58/f5ac415a1781877b21e88f9257c7071e48ee91c34ca461e880b74677758a/mlx-0.30.4-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:96ad421cfe62a6fe7fc98521f8af9a530d7d7b6ded402ba6f4eb81a4a3087d1f", size = 574120 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/12/9eb62ebf0ca7989efa6dec92e79630ef70e54202b756523bdeadf3c009eb/mlx-0.30.4-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:dfafd24144d91f6b4bd5ef6711458c566fdf507aee6417567fc2da0469619878", size = 574112 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/f3/ada2b2126fc7a2634bd30c07418c6ae9657530d4534249c6949dbcc0013d/mlx-0.30.4-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:f016e16ff43dff6240ee91a8ba32226db1d55797a81a64d7af84e0e4409852ba", size = 622977 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/8d/fc498b847f9ed8459ee89fb5b06f7237541192a9e6cd965bed9f61114f5c/mlx-0.30.4-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:962f99d637a99058b7d7659b66570f988815f26f2ae9af52c4cd0359fab928e2", size = 662314 },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/bf/e399eac4371a9d80a4bff41e1500a914a852d026ef84c30c8e427964752f/mlx-0.30.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:1fd93ce3de56b01699dc18a9673c89510f1772125fc48f6385dca66be34500fc", size = 572736 },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/e3/ad176bc95791eb09dd4414811c8fc27741c3e4832824bf1f7881f8103928/mlx-0.30.5-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:8e24d40e8c5827a242d7b8ba2379c8ece5776b7163e05ceb7a110ec161582f91", size = 572738 },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/16/a1b4da06eec1efe8bb84ab450c790e7291a8610d619517965d71d31aa094/mlx-0.30.5-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:481555f8632a54a36405a80348d88631564cecf1c16c3a9a9bf971b11c353887", size = 572813 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/7c/97c1539a593393c65cf19a0ae89005fafa13321cfb185dbb409fe9b0a955/mlx-0.30.5-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:e89df90121910595529d37711a8a14f345e468b6e6caa453f40578acc90b7297", size = 621230 },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/ef/a266e0584ab5e301568eecf1b1e83ba92e87a395de6a7539404386ffd555/mlx-0.30.5-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:027412cf65319c6d09d726a4910795811c995adeecd4d19b859a5c4eab0b0cf2", size = 662868 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/7f/2089917d98a12528a43f50e0cae675432322a36683252451d06819dbf8b1/mlx-0.30.5-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:ce54786846716f43b18733af1ccbd7f4801d659e6191da9d8297e88950533533", size = 574246 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/09/7950b1587fab3fb4922371628d57b643297ec7e336dcfb44acf2738b087f/mlx-0.30.5-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:da405026dadbf7e5c838bdbf941484c307ad4629b8f137e52782c3093ef00fc9", size = 574245 },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/86/8592f35b2cb25077bb5e8bf0566b9bb2061b36baa774db3e5990e59848a4/mlx-0.30.5-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:d2cc2d131f65f83141ffeb858eda3ae4b694030b2e1d7ed393b3a8f5fa6944f8", size = 574262 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/c7/9393ae41bff83a69321aea8d8f481faf530f3b3f5912e364094ef2245e2b/mlx-0.30.5-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:3a58a108ed4cea1f644fcf0cf2a7fbaabcc43f3022da46b61cf50a7d80fe77f8", size = 623221 },
|
||||
{ url = "https://files.pythonhosted.org/packages/74/e9/ee0fa04cb294382f778c7800367e8bcceac1832d0e17554ab8cabd2fc48a/mlx-0.30.5-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:e623031b72e49b3c2ffbca203cc6a853f0c6c11afe271a9b410598d3fd9479f8", size = 662788 },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -1048,16 +1048,16 @@ cuda13 = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.4"
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/71/8a2f3598d628c6e5fd6ca4c58e080311dc39c558561d8f7fb2d91865f0e6/mlx_cpu-0.30.4-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:d3de6152e38f8a884d7cadb5e633bcf5fb346434867195709b4f6db8450e3f91", size = 8684835 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/c4/eae335cf6859c4a45be52888b754bfceb0ad5363bd05ae0ce3e67fac1dec/mlx_cpu-0.30.4-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:1b7f076587d1bd028a6f8197fe35721a39b8202e36b05e3aba89d29d79ab6764", size = 10257054 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/23/5837effd6be5ed322a456623a4ad51d7c569ee0234284d96c081f73f8bda/mlx_cpu-0.30.5-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:5e848db0830f70be3e706e465badf0570ebee3379470de9fb84c8d5b45b9e2e5", size = 8685370 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/0e/b5c98e9a2d767a967b5f0e464375bb4bbc4da2190a51586ef3bdca4e7875/mlx_cpu-0.30.5-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:3d32fef8958c3e1abff39e77309d715f131609ad63c8c3be0f54e2778fbd36ba", size = 10256264 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cuda-13"
|
||||
version = "0.30.4"
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas", marker = "sys_platform == 'linux'" },
|
||||
@@ -1066,14 +1066,14 @@ dependencies = [
|
||||
{ name = "nvidia-nccl-cu13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/79/e0aec1bf713eb6f6cbda69e1f4d145429e0477c3087aa41755078caafcb7/mlx_cuda_13-0.30.4-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:4edc42cb2e00a7e51621afd4c0b43154f7e2a15a3c6516878207cda9f85ee133", size = 66771153 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/ff/1793ec5ec7f486bc44356ac5d355a42577bf5c4c72c42feb9b237bc00838/mlx_cuda_13-0.30.4-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:23a4b617e8bcd5581e6d257ac09fee85f0195114523a863ba84118cfac4abb26", size = 69665061 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/cb/10622469930c6f4024cc9853be387158d9e39e0e19d97ee07120265e9ddc/mlx_cuda_13-0.30.5-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:6cd61592b4624ced451b312f1ce0cb6c951c438d9639c1b56dc43abfebb0fbaf", size = 66899934 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/9e/144f35c6b9ceaec2ddda184868e58e2a09188e2f82547dd1eabc74adddd4/mlx_cuda_13-0.30.5-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:b88db12c6eec02667902edc5a7ef3febac0822412a1477259f06abb564d09724", size = 69738714 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.5"
|
||||
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -1083,15 +1083,19 @@ dependencies = [
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.4"
|
||||
version = "0.30.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/95/b1/a50b84aaa76a60605606df49196456f31871148485ede7cbe3267a25a51e/mlx_metal-0.30.4-py3-none-macosx_14_0_arm64.whl", hash = "sha256:10c417f86778ac5529ecd2180f90de35f2d3a0fcad4d5176d211d651504c4922", size = 38260996 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/f0/6cce9e0ea545f61d0fa27dc6cd30ffa0e44f17bf859e5d75a34a9ba0da56/mlx_metal-0.30.4-py3-none-macosx_15_0_arm64.whl", hash = "sha256:f48f52490f0fcb2be924312d50c3a12625249d396a2a119ce4f7b0d388543ca9", size = 38255657 },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/fc/345f627bb88479cb53c3f37ad1947f865830060a3d792eec05954f53384d/mlx_metal-0.30.4-py3-none-macosx_26_0_arm64.whl", hash = "sha256:9a9fb6f9169eeb38a7f78389fe78306a1b5167fa489096bc50f9ca72074d7a95", size = 47541040 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/60/1bb7019c10645a3a7f3c5decea9241f5ff18133da96d396dabf27dd6f6c4/mlx_metal-0.30.5-py3-none-macosx_14_0_arm64.whl", hash = "sha256:f8287d0204170231aefda0b03acbb3b54336848be83e7916d7ef18d44e23593a", size = 38259369 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/4a/0b6a59629f9ee4de76c1b018a0c4d5a38844048f6ebccc2c72d2dba89bf0/mlx_metal-0.30.5-py3-none-macosx_15_0_arm64.whl", hash = "sha256:9ce0b169ec7043d46d63a7b04e048fdd51246a134c8ee7eccb9dfb8db9f3f468", size = 38254849 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/d4/038a7281a44366d963b30b3867423f59f65e3dae479cffe0e63855c600f3/mlx_metal-0.30.5-py3-none-macosx_26_0_arm64.whl", hash = "sha256:04cfe0bf2b95a294c5072e7084a708ca3e1ba9454775af17b1be2877a4606a9f", size = 47536555 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1503,6 +1507,9 @@ version = "11.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652 },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443 },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407 },
|
||||
|
||||
Reference in New Issue
Block a user