Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
1e29ff41ed fix: forcefully terminate runner on Shutdown timeout
When an instance is deleted while the runner is loading a model,
the runner can't process the Shutdown task because mx.eval() is
blocking. Previously the runner process was left running as an
orphan. Now we call supervisor.shutdown() to kill the process
and free all resources immediately.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 15:45:58 -08:00
7 changed files with 7 additions and 256 deletions

View File

@@ -1,151 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -189,7 +189,6 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -35,9 +35,6 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -267,19 +264,6 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -443,14 +427,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -1005,46 +981,3 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model

View File

@@ -224,10 +224,16 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
supervisor = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await supervisor.start_task(task)
except TimeoutError:
logger.warning(
f"Runner {runner_id} did not respond to Shutdown within 3s, "
f"forcefully terminating process"
)
supervisor.shutdown()
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut