Compare commits

...

6 Commits

Author SHA1 Message Date
Alex Cheema
8c57df8b37 fix: enable psutil fallback for memory monitoring when macmon is missing on macOS
On Darwin, the psutil memory poller was disabled (memory_poll_rate=None),
relying entirely on macmon. When macmon is not installed, no memory data
was reported, causing nodes to show zero memory in the cluster state and
blocking shard placement.

Now falls back to psutil-based memory polling when macmon is not found.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 08:34:54 -08:00
Alex Cheema
1c3cc699d4 fix: add missing getModelFitStatus prop to Recent tab (#1470)
## Summary
- Clicking the **Recent** tab in the Model Picker crashed with
`TypeError: e.getModelFitStatus is not a function`
- The `ModelPickerGroup` component in the Recent tab section was missing
the `{getModelFitStatus}` prop, while all other tabs (e.g., the main
model list) passed it correctly
- Added the missing `{getModelFitStatus}` prop so the Recent tab renders
without errors, matching the behavior of the other tabs

## Test plan
- [ ] Open the dashboard and click **SELECT MODEL**
- [ ] Switch to the **Recent** tab — verify it renders without crashing
- [ ] Confirm model fit status indicators display correctly on recent
models
- [ ] Verify the other tabs (All, Favorites) still work as before

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:37:32 +00:00
rltakashige
5a28642790 Add support for Step 3.5 flash! (#1460)
## Motivation

Working version of #1366 

## Changes

Add Step 3.5 Flash

## Test Plan

### Manual Testing
Works!

### Automated Testing
Running two processes tensor/pipeline sharded gives same logits as
single process.
2026-02-13 12:10:18 +00:00
Alex Cheema
6950f94109 dashboard: show macOS version in debug mode (#1454)
## Motivation

When debugging cluster issues, it's useful to see which macOS version
each node is running — especially since version mismatches can cause
compatibility problems. The OS version data is already collected by the
identity gatherer but wasn't shown in the topology graph.

## Changes

- Added OS version label (e.g. "macOS 15.2") to each node in the
topology graph when debug mode is enabled
- Renders below the existing TB and RDMA debug labels using the same
styling conventions
- Sources data from the existing `nodeIdentities` store (no backend
changes needed)

## Why It Works

The `nodeIdentities` store already contains `osVersion` for each node.
We simply read it in the `TopologyGraph` component and append a text
label in the debug section, following the exact same pattern as the TB
and RDMA labels.

## Test Plan

### Manual Testing
<!-- Hardware: MacBook Pro -->
- Enable debug mode in the dashboard
- Verify OS version label appears below TB/RDMA labels on each node
- Verify label disappears when debug mode is disabled

### Automated Testing
- Dashboard build passes (`npm run build`)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-12 17:56:55 +00:00
Alex Cheema
d0c44273db feat: add enable_thinking toggle for thinking-capable models (#1457)
## Motivation

Fixes #1456. Models like DeepSeek V3.2, Qwen3, and GLM-4.7 always run in
thinking mode because their chat templates auto-inject `<think>`. Users
need a way to disable thinking for models that support both modes.

## Changes

**API**: Added `enable_thinking: bool | None` to `ChatCompletionRequest`
and `TextGenerationTaskParams`. Passed through the adapter to
`tokenizer.apply_chat_template()` as a kwarg (only when explicitly set,
so models without the template variable are unaffected).

**Dashboard**: Added a thinking toggle button in the chat input area.
Visible only when the selected model has both "text" and "thinking"
capabilities.

## Why It Works

Most thinking model chat templates (DeepSeek, Qwen3, GLM) accept
`enable_thinking` as a Jinja template variable. Passing
`enable_thinking=False` prevents the template from injecting `<think>`,
matching the vLLM convention.

## Test Plan

### Manual Testing
- `curl` with `"enable_thinking": false` against a thinking model — no
`<think>` in output
- Dashboard toggle visible for thinking models, hidden for text-only
models

### Automated Testing
- basedpyright: 0 errors
- ruff: clean
- pytest: 188 passed
- dashboard build: success

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 17:35:24 +00:00
Jake Hillion
cc33213842 bench: add --settle-timeout for cluster startup retry (#1449)
exo_bench.py fails if started too soon after a cluster starts because
the topology hasn't populated yet, resulting in no valid placements.

Extracted the preview-fetch-and-filter logic into a
`fetch_and_filter_placements` helper and added a retry loop with
exponential backoff (1s initial, 2x multiplier, 60s cap). The new
`--settle-timeout` flag controls how long to retry (default 0 = try
once, preserving existing behaviour). Each retry logs a warning
explaining the cluster may still be settling.

Test plan:
- Tested on several freshly started clusters. This used to fail a lot,
  now it succeeds.
2026-02-12 16:38:09 +00:00
17 changed files with 462 additions and 53 deletions

View File

@@ -0,0 +1,151 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -19,6 +19,11 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -388,6 +393,66 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
@@ -464,6 +529,12 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -487,11 +558,6 @@ def main() -> int:
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
tokenizer = load_tokenizer_for_bench(full_model_id)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -503,54 +569,20 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
selected = fetch_and_filter_placements(client, full_model_id, args)
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
if not selected and args.settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + args.settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")

View File

@@ -12,6 +12,8 @@
ttftMs,
tps,
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -25,6 +27,7 @@
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
modelCapabilities?: Record<string, string[]>;
}
let {
@@ -34,6 +37,7 @@
autofocus = true,
showModelSelector = false,
modelTasks = {},
modelCapabilities = {},
}: Props = $props();
let message = $state("");
@@ -41,6 +45,7 @@
let fileInputRef: HTMLInputElement | undefined = $state();
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let isDragOver = $state(false);
const thinkingEnabled = $derived(thinkingEnabledStore());
let loading = $derived(isLoading());
const currentModel = $derived(selectedChatModel());
const instanceData = $derived(instances());
@@ -95,6 +100,12 @@
);
});
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(
currentModel !== null &&
modelSupportsOnlyImageEditing(currentModel) &&
@@ -282,7 +293,11 @@
// Use image generation for text-to-image models
generateImage(content);
} else {
sendMessage(content, files);
sendMessage(
content,
files,
modelSupportsThinking() ? thinkingEnabled : null,
);
}
// Refocus the textarea after sending
@@ -520,6 +535,35 @@
</div>
{/if}
</div>
<!-- Thinking toggle -->
{#if modelSupportsThinking()}
<button
type="button"
onclick={() => setConversationThinking(!thinkingEnabled)}
class="flex items-center gap-1.5 px-2 py-1 rounded text-xs font-mono tracking-wide transition-all duration-200 flex-shrink-0 cursor-pointer border {thinkingEnabled
? 'bg-exo-yellow/15 border-exo-yellow/40 text-exo-yellow'
: 'bg-exo-medium-gray/30 border-exo-medium-gray/50 text-exo-light-gray/60 hover:text-exo-light-gray'}"
title={thinkingEnabled
? "Thinking enabled — click to disable"
: "Thinking disabled — click to enable"}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
<span>{thinkingEnabled ? "THINK" : "NO THINK"}</span>
</button>
{/if}
<!-- Performance stats -->
{#if currentTtft !== null || currentTps !== null}
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">

View File

@@ -806,6 +806,7 @@
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}

View File

@@ -7,6 +7,7 @@
debugMode,
nodeThunderboltBridge,
nodeRdmaCtl,
nodeIdentities,
type NodeInfo,
} from "$lib/stores/app.svelte";
@@ -33,6 +34,7 @@
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
const rdmaCtlData = $derived(nodeRdmaCtl());
const identitiesData = $derived(nodeIdentities());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -1177,6 +1179,22 @@
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(rdmaText);
debugLabelY += debugLineHeight;
}
const identity = identitiesData[nodeInfo.id];
if (identity?.osVersion) {
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", debugLabelY)
.attr("text-anchor", "middle")
.attr("fill", "rgba(179,179,179,0.7)")
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(
`macOS ${identity.osVersion}${identity.osBuildVersion ? ` (${identity.osBuildVersion})` : ""}`,
);
}
}
});

View File

@@ -296,6 +296,7 @@ export interface Conversation {
modelId: string | null;
sharding: string | null;
instanceType: string | null;
enableThinking: boolean | null;
}
const STORAGE_KEY = "exo-conversations";
@@ -605,6 +606,7 @@ class AppStore {
modelId: conversation.modelId ?? null,
sharding: conversation.sharding ?? null,
instanceType: conversation.instanceType ?? null,
enableThinking: conversation.enableThinking ?? null,
}));
}
} catch (error) {
@@ -794,6 +796,7 @@ class AppStore {
modelId: derivedModelId,
sharding: derivedSharding,
instanceType: derivedInstanceType,
enableThinking: null,
};
this.conversations.unshift(conversation);
@@ -819,6 +822,7 @@ class AppStore {
this.hasStartedChat = true;
this.isTopologyMinimized = true;
this.isSidebarOpen = true; // Auto-open sidebar when chatting
this.thinkingEnabled = conversation.enableThinking ?? true;
this.refreshConversationModelFromInstances();
return true;
@@ -1932,6 +1936,11 @@ class AppStore {
}
}
/**
* Whether thinking is enabled for the current conversation
*/
thinkingEnabled = $state(true);
/**
* Selected model for chat (can be set by the UI)
*/
@@ -2110,6 +2119,7 @@ class AppStore {
textContent?: string;
preview?: string;
}[],
enableThinking?: boolean | null,
): Promise<void> {
if ((!content.trim() && (!files || files.length === 0)) || this.isLoading)
return;
@@ -2257,6 +2267,9 @@ class AppStore {
stream: true,
logprobs: true,
top_logprobs: 5,
...(enableThinking != null && {
enable_thinking: enableThinking,
}),
}),
});
@@ -2915,6 +2928,18 @@ class AppStore {
);
}
/**
* Update the thinking preference for the active conversation
*/
setConversationThinking(enabled: boolean) {
this.thinkingEnabled = enabled;
const conv = this.getActiveConversation();
if (conv) {
conv.enableThinking = enabled;
this.saveConversationsToStorage();
}
}
/**
* Start a download on a specific node
*/
@@ -3028,6 +3053,7 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const thinkingEnabled = () => appStore.thinkingEnabled;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -3043,7 +3069,8 @@ export const sendMessage = (
textContent?: string;
preview?: string;
}[],
) => appStore.sendMessage(content, files);
enableThinking?: boolean | null,
) => appStore.sendMessage(content, files, enableThinking);
export const generateImage = (prompt: string, modelId?: string) =>
appStore.generateImage(prompt, modelId);
export const editImage = (
@@ -3086,6 +3113,8 @@ export const deleteAllConversations = () => appStore.deleteAllConversations();
export const renameConversation = (id: string, name: string) =>
appStore.renameConversation(id, name);
export const getActiveConversation = () => appStore.getActiveConversation();
export const setConversationThinking = (enabled: boolean) =>
appStore.setConversationThinking(enabled);
// Sidebar actions
export const isSidebarOpen = () => appStore.isSidebarOpen;

View File

@@ -190,6 +190,19 @@
return tasks;
});
const modelCapabilities = $derived(() => {
const caps: Record<string, string[]> = {};
for (const model of models) {
if (model.capabilities && model.capabilities.length > 0) {
caps[model.id] = model.capabilities;
if (model.hugging_face_id) {
caps[model.hugging_face_id] = model.capabilities;
}
}
}
return caps;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(
@@ -2270,6 +2283,7 @@
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
modelCapabilities={modelCapabilities()}
/>
</div>
</div>
@@ -3049,6 +3063,7 @@
placeholder="Ask anything"
showModelSelector={true}
modelTasks={modelTasks()}
modelCapabilities={modelCapabilities()}
/>
</div>
</div>

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -79,6 +79,7 @@ def chat_request_to_text_generation(
seed=request.seed,
stream=request.stream,
tools=request.tools,
enable_thinking=request.enable_thinking,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,

View File

@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -199,6 +199,7 @@ class ChatCompletionRequest(BaseModel):
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
enable_thinking: bool | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None

View File

@@ -40,5 +40,6 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
stop: str | list[str] | None = None
seed: int | None = None
chat_template_messages: list[dict[str, Any]] | None = None
enable_thinking: bool | None = None
logprobs: bool = False
top_logprobs: int | None = None

View File

@@ -388,6 +388,10 @@ class InfoGatherer:
if IS_DARWIN:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
else:
# macmon not installed — fall back to psutil for memory
logger.warning("macmon not found, falling back to psutil for memory monitoring")
self.memory_poll_rate = 1
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._monitor_rdma_ctl_status)

View File

@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model

View File

@@ -462,11 +462,19 @@ def apply_chat_template(
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
formatted_messages = formatted_messages[:-1]
extra_kwargs: dict[str, Any] = {}
if task_params.enable_thinking is not None:
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
# Jinja ignores unknown variables, so passing both is safe.
extra_kwargs["enable_thinking"] = task_params.enable_thinking
extra_kwargs["thinking"] = task_params.enable_thinking
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
**extra_kwargs,
)
if partial_assistant_content: