Compare commits

...

6 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
b0e6a6ac08 fix gpt oss tool calling 2026-02-13 21:23:26 +00:00
Alex Cheema
1c3cc699d4 fix: add missing getModelFitStatus prop to Recent tab (#1470)
## Summary
- Clicking the **Recent** tab in the Model Picker crashed with
`TypeError: e.getModelFitStatus is not a function`
- The `ModelPickerGroup` component in the Recent tab section was missing
the `{getModelFitStatus}` prop, while all other tabs (e.g., the main
model list) passed it correctly
- Added the missing `{getModelFitStatus}` prop so the Recent tab renders
without errors, matching the behavior of the other tabs

## Test plan
- [ ] Open the dashboard and click **SELECT MODEL**
- [ ] Switch to the **Recent** tab — verify it renders without crashing
- [ ] Confirm model fit status indicators display correctly on recent
models
- [ ] Verify the other tabs (All, Favorites) still work as before

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:37:32 +00:00
rltakashige
5a28642790 Add support for Step 3.5 flash! (#1460)
## Motivation

Working version of #1366 

## Changes

Add Step 3.5 Flash

## Test Plan

### Manual Testing
Works!

### Automated Testing
Running two processes tensor/pipeline sharded gives same logits as
single process.
2026-02-13 12:10:18 +00:00
Alex Cheema
6950f94109 dashboard: show macOS version in debug mode (#1454)
## Motivation

When debugging cluster issues, it's useful to see which macOS version
each node is running — especially since version mismatches can cause
compatibility problems. The OS version data is already collected by the
identity gatherer but wasn't shown in the topology graph.

## Changes

- Added OS version label (e.g. "macOS 15.2") to each node in the
topology graph when debug mode is enabled
- Renders below the existing TB and RDMA debug labels using the same
styling conventions
- Sources data from the existing `nodeIdentities` store (no backend
changes needed)

## Why It Works

The `nodeIdentities` store already contains `osVersion` for each node.
We simply read it in the `TopologyGraph` component and append a text
label in the debug section, following the exact same pattern as the TB
and RDMA labels.

## Test Plan

### Manual Testing
<!-- Hardware: MacBook Pro -->
- Enable debug mode in the dashboard
- Verify OS version label appears below TB/RDMA labels on each node
- Verify label disappears when debug mode is disabled

### Automated Testing
- Dashboard build passes (`npm run build`)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-12 17:56:55 +00:00
Alex Cheema
d0c44273db feat: add enable_thinking toggle for thinking-capable models (#1457)
## Motivation

Fixes #1456. Models like DeepSeek V3.2, Qwen3, and GLM-4.7 always run in
thinking mode because their chat templates auto-inject `<think>`. Users
need a way to disable thinking for models that support both modes.

## Changes

**API**: Added `enable_thinking: bool | None` to `ChatCompletionRequest`
and `TextGenerationTaskParams`. Passed through the adapter to
`tokenizer.apply_chat_template()` as a kwarg (only when explicitly set,
so models without the template variable are unaffected).

**Dashboard**: Added a thinking toggle button in the chat input area.
Visible only when the selected model has both "text" and "thinking"
capabilities.

## Why It Works

Most thinking model chat templates (DeepSeek, Qwen3, GLM) accept
`enable_thinking` as a Jinja template variable. Passing
`enable_thinking=False` prevents the template from injecting `<think>`,
matching the vLLM convention.

## Test Plan

### Manual Testing
- `curl` with `"enable_thinking": false` against a thinking model — no
`<think>` in output
- Dashboard toggle visible for thinking models, hidden for text-only
models

### Automated Testing
- basedpyright: 0 errors
- ruff: clean
- pytest: 188 passed
- dashboard build: success

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 17:35:24 +00:00
Jake Hillion
cc33213842 bench: add --settle-timeout for cluster startup retry (#1449)
exo_bench.py fails if started too soon after a cluster starts because
the topology hasn't populated yet, resulting in no valid placements.

Extracted the preview-fetch-and-filter logic into a
`fetch_and_filter_placements` helper and added a retry loop with
exponential backoff (1s initial, 2x multiplier, 60s cap). The new
`--settle-timeout` flag controls how long to retry (default 0 = try
once, preserving existing behaviour). Each retry logs a warning
explaining the cluster may still be settling.

Test plan:
- Tested on several freshly started clusters. This used to fail a lot,
  now it succeeds.
2026-02-12 16:38:09 +00:00
17 changed files with 466 additions and 54 deletions

View File

@@ -0,0 +1,151 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -19,6 +19,11 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -388,6 +393,66 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
@@ -464,6 +529,12 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -487,11 +558,6 @@ def main() -> int:
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
tokenizer = load_tokenizer_for_bench(full_model_id)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -503,54 +569,20 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
selected = fetch_and_filter_placements(client, full_model_id, args)
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
if not selected and args.settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + args.settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")

View File

@@ -12,6 +12,8 @@
ttftMs,
tps,
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -25,6 +27,7 @@
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
modelCapabilities?: Record<string, string[]>;
}
let {
@@ -34,6 +37,7 @@
autofocus = true,
showModelSelector = false,
modelTasks = {},
modelCapabilities = {},
}: Props = $props();
let message = $state("");
@@ -41,6 +45,7 @@
let fileInputRef: HTMLInputElement | undefined = $state();
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let isDragOver = $state(false);
const thinkingEnabled = $derived(thinkingEnabledStore());
let loading = $derived(isLoading());
const currentModel = $derived(selectedChatModel());
const instanceData = $derived(instances());
@@ -95,6 +100,12 @@
);
});
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(
currentModel !== null &&
modelSupportsOnlyImageEditing(currentModel) &&
@@ -282,7 +293,11 @@
// Use image generation for text-to-image models
generateImage(content);
} else {
sendMessage(content, files);
sendMessage(
content,
files,
modelSupportsThinking() ? thinkingEnabled : null,
);
}
// Refocus the textarea after sending
@@ -520,6 +535,35 @@
</div>
{/if}
</div>
<!-- Thinking toggle -->
{#if modelSupportsThinking()}
<button
type="button"
onclick={() => setConversationThinking(!thinkingEnabled)}
class="flex items-center gap-1.5 px-2 py-1 rounded text-xs font-mono tracking-wide transition-all duration-200 flex-shrink-0 cursor-pointer border {thinkingEnabled
? 'bg-exo-yellow/15 border-exo-yellow/40 text-exo-yellow'
: 'bg-exo-medium-gray/30 border-exo-medium-gray/50 text-exo-light-gray/60 hover:text-exo-light-gray'}"
title={thinkingEnabled
? "Thinking enabled — click to disable"
: "Thinking disabled — click to enable"}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
<span>{thinkingEnabled ? "THINK" : "NO THINK"}</span>
</button>
{/if}
<!-- Performance stats -->
{#if currentTtft !== null || currentTps !== null}
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">

View File

@@ -806,6 +806,7 @@
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}

View File

@@ -7,6 +7,7 @@
debugMode,
nodeThunderboltBridge,
nodeRdmaCtl,
nodeIdentities,
type NodeInfo,
} from "$lib/stores/app.svelte";
@@ -33,6 +34,7 @@
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
const rdmaCtlData = $derived(nodeRdmaCtl());
const identitiesData = $derived(nodeIdentities());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -1177,6 +1179,22 @@
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(rdmaText);
debugLabelY += debugLineHeight;
}
const identity = identitiesData[nodeInfo.id];
if (identity?.osVersion) {
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", debugLabelY)
.attr("text-anchor", "middle")
.attr("fill", "rgba(179,179,179,0.7)")
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(
`macOS ${identity.osVersion}${identity.osBuildVersion ? ` (${identity.osBuildVersion})` : ""}`,
);
}
}
});

View File

@@ -296,6 +296,7 @@ export interface Conversation {
modelId: string | null;
sharding: string | null;
instanceType: string | null;
enableThinking: boolean | null;
}
const STORAGE_KEY = "exo-conversations";
@@ -605,6 +606,7 @@ class AppStore {
modelId: conversation.modelId ?? null,
sharding: conversation.sharding ?? null,
instanceType: conversation.instanceType ?? null,
enableThinking: conversation.enableThinking ?? null,
}));
}
} catch (error) {
@@ -794,6 +796,7 @@ class AppStore {
modelId: derivedModelId,
sharding: derivedSharding,
instanceType: derivedInstanceType,
enableThinking: null,
};
this.conversations.unshift(conversation);
@@ -819,6 +822,7 @@ class AppStore {
this.hasStartedChat = true;
this.isTopologyMinimized = true;
this.isSidebarOpen = true; // Auto-open sidebar when chatting
this.thinkingEnabled = conversation.enableThinking ?? true;
this.refreshConversationModelFromInstances();
return true;
@@ -1932,6 +1936,11 @@ class AppStore {
}
}
/**
* Whether thinking is enabled for the current conversation
*/
thinkingEnabled = $state(true);
/**
* Selected model for chat (can be set by the UI)
*/
@@ -2110,6 +2119,7 @@ class AppStore {
textContent?: string;
preview?: string;
}[],
enableThinking?: boolean | null,
): Promise<void> {
if ((!content.trim() && (!files || files.length === 0)) || this.isLoading)
return;
@@ -2257,6 +2267,9 @@ class AppStore {
stream: true,
logprobs: true,
top_logprobs: 5,
...(enableThinking != null && {
enable_thinking: enableThinking,
}),
}),
});
@@ -2915,6 +2928,18 @@ class AppStore {
);
}
/**
* Update the thinking preference for the active conversation
*/
setConversationThinking(enabled: boolean) {
this.thinkingEnabled = enabled;
const conv = this.getActiveConversation();
if (conv) {
conv.enableThinking = enabled;
this.saveConversationsToStorage();
}
}
/**
* Start a download on a specific node
*/
@@ -3028,6 +3053,7 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const thinkingEnabled = () => appStore.thinkingEnabled;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -3043,7 +3069,8 @@ export const sendMessage = (
textContent?: string;
preview?: string;
}[],
) => appStore.sendMessage(content, files);
enableThinking?: boolean | null,
) => appStore.sendMessage(content, files, enableThinking);
export const generateImage = (prompt: string, modelId?: string) =>
appStore.generateImage(prompt, modelId);
export const editImage = (
@@ -3086,6 +3113,8 @@ export const deleteAllConversations = () => appStore.deleteAllConversations();
export const renameConversation = (id: string, name: string) =>
appStore.renameConversation(id, name);
export const getActiveConversation = () => appStore.getActiveConversation();
export const setConversationThinking = (enabled: boolean) =>
appStore.setConversationThinking(enabled);
// Sidebar actions
export const isSidebarOpen = () => appStore.isSidebarOpen;

View File

@@ -190,6 +190,19 @@
return tasks;
});
const modelCapabilities = $derived(() => {
const caps: Record<string, string[]> = {};
for (const model of models) {
if (model.capabilities && model.capabilities.length > 0) {
caps[model.id] = model.capabilities;
if (model.hugging_face_id) {
caps[model.hugging_face_id] = model.capabilities;
}
}
}
return caps;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(
@@ -2270,6 +2283,7 @@
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
modelCapabilities={modelCapabilities()}
/>
</div>
</div>
@@ -3049,6 +3063,7 @@
placeholder="Ask anything"
showModelSelector={true}
modelTasks={modelTasks()}
modelCapabilities={modelCapabilities()}
/>
</div>
</div>

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -79,6 +79,7 @@ def chat_request_to_text_generation(
seed=request.seed,
stream=request.stream,
tools=request.tools,
enable_thinking=request.enable_thinking,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,

View File

@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -199,6 +199,7 @@ class ChatCompletionRequest(BaseModel):
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
enable_thinking: bool | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None

View File

@@ -40,5 +40,6 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
stop: str | list[str] | None = None
seed: int | None = None
chat_template_messages: list[dict[str, Any]] | None = None
enable_thinking: bool | None = None
logprobs: bool = False
top_logprobs: int | None = None

View File

@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model

View File

@@ -316,6 +316,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None
@@ -462,11 +464,19 @@ def apply_chat_template(
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
formatted_messages = formatted_messages[:-1]
extra_kwargs: dict[str, Any] = {}
if task_params.enable_thinking is not None:
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
# Jinja ignores unknown variables, so passing both is safe.
extra_kwargs["enable_thinking"] = task_params.enable_thinking
extra_kwargs["thinking"] = task_params.enable_thinking
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
**extra_kwargs,
)
if partial_assistant_content:

View File

@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -568,7 +569,11 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel