Compare commits

..

16 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
58e751a930 revert 2026-02-13 20:50:27 +00:00
Ryuichi Leo Takashige
6718da7af3 try optimisations 2026-02-13 19:54:23 +00:00
Ryuichi Leo Takashige
9d9237f68f delete unnecessary files 2026-02-13 19:54:08 +00:00
Ryuichi Leo Takashige
8de4e10736 fix depends for CacheList 2026-02-13 15:29:53 +00:00
Ryuichi Leo Takashige
0de3e486df update glm 5 to use upstream mlx lm 2026-02-13 12:50:08 +00:00
Ryuichi Leo Takashige
ce0eef999e return to mlx lm main 2026-02-13 12:31:07 +00:00
Ryuichi Leo Takashige
20fb6a9acc handle absolute paths 2026-02-13 11:09:46 +00:00
Ryuichi Leo Takashige
4a1234106b add type stub 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2929249147 fix glm eos id 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
837ffc6b97 dont patch glm5 tokenizer? 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2366ed0299 add glm5 model cards 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
c95c088952 convert glm5 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2af1c81cde convert glm5 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
6922dd4ead download faster 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
8c2fb7f130 Add tensor sharding 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
0488cb2967 update pyproject.toml 2026-02-12 23:46:13 +00:00
25 changed files with 161 additions and 383 deletions

View File

@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

View File

@@ -1,151 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -806,7 +806,6 @@
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}

View File

@@ -72,8 +72,6 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
// Granular node state types from the new state structure

22
download_glm5_shard.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Usage: ./download_glm5_shard.sh <start> <end> [local_dir]
#
# Split across 4 Macs:
# Mac 1: ./download_glm5_shard.sh 1 71
# Mac 2: ./download_glm5_shard.sh 72 141
# Mac 3: ./download_glm5_shard.sh 142 212
# Mac 4: ./download_glm5_shard.sh 213 282
set -euo pipefail
START=${1:?Usage: $0 <start> <end> [local_dir]}
END=${2:?Usage: $0 <start> <end> [local_dir]}
LOCAL_DIR="${3:-GLM-5}"
INCLUDES=()
for i in $(seq "$START" "$END"); do
INCLUDES+=(--include "$(printf 'model-%05d-of-00282.safetensors' "$i")")
done
echo "Downloading safetensors $START-$END to $LOCAL_DIR"
hf download zai-org/GLM-5 "${INCLUDES[@]}" --local-dir "$LOCAL_DIR"

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx==0.30.6",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,6 +64,8 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
#mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", marker = "sys_platform == 'darwin'" }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -24,7 +24,6 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
TextGeneration,
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -321,14 +319,6 @@ class Master:
chunk=chunk,
)
)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -153,8 +153,6 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -169,8 +167,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -12,7 +12,6 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -73,8 +72,6 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -193,25 +190,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -7,11 +7,17 @@ from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
def _resolve_env_path(env_value: str) -> Path:
"""Resolve an environment variable path: absolute paths are used as-is, relative paths are resolved from home."""
p = Path(env_value)
return p if p.is_absolute() else Path.home() / p
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
return _resolve_env_path(_EXO_HOME_ENV)
if sys.platform != "linux":
return Path.home() / ".exo"
@@ -31,15 +37,19 @@ _EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
EXO_MODELS_DIR = (
EXO_DATA_HOME / "models"
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
else _resolve_env_path(_EXO_MODELS_DIR_ENV)
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
find_resources()
if _RESOURCES_DIR_ENV is None
else _resolve_env_path(_RESOURCES_DIR_ENV)
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
find_dashboard()
if _DASHBOARD_DIR_ENV is None
else _resolve_env_path(_DASHBOARD_DIR_ENV)
)
# Log files (data/logs or cache)

View File

@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
@@ -189,7 +190,6 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -38,8 +38,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None
num_draft_tokens: int = 4
class CreateInstance(BaseCommand):
@@ -74,14 +72,6 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
@@ -99,7 +89,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
| SendInputChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -68,14 +68,6 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -149,7 +141,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut

View File

@@ -40,12 +40,6 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -86,17 +80,9 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
@@ -104,5 +90,4 @@ Task = (
| ImageGeneration
| ImageEdits
| Shutdown
| SetDraftModel
)

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -35,9 +35,6 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -163,11 +160,14 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
dep_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
dep_cache.keys = mx.depends(dep_cache.keys, output) # type: ignore[reportUnknownMemberType]
if self.is_prefill:
mx.eval(output)
if cache is not None:
mx.eval(cache.keys) # type: ignore
mx.eval(dep_cache.keys) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[
@@ -267,19 +267,6 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -307,7 +294,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
last = cache[-1] # type: ignore
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
return logits
@@ -333,7 +322,9 @@ def patch_tensor_model[T](model: T) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
last = cache[-1] # pyright: ignore[reportAny]
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
@@ -443,14 +434,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -547,11 +530,13 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
# Shard attention heads
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
@@ -561,10 +546,11 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
@@ -581,12 +567,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -1005,46 +996,3 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model

View File

@@ -223,27 +223,6 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: ModelId) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
@@ -332,10 +311,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
# For GLM-5 and GLM-4.7
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
return None

View File

@@ -295,8 +295,8 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
elif "glm-4" in shard_metadata.model_card.model_id.lower():
patch_glm4_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
@@ -863,7 +863,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
def patch_glm4_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""

12
uv.lock generated
View File

@@ -416,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1098,8 +1098,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
version = "0.30.7"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#bcf630614ffb5624bcb19870a7bcb0d847e6e98f" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1109,10 +1109,6 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"