mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-14 08:04:15 -05:00
Compare commits
21 Commits
runner-can
...
e2e-tests
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bf4d1f585 | ||
|
|
5e27e4e719 | ||
|
|
b249757116 | ||
|
|
5c0b769bf8 | ||
|
|
702886d147 | ||
|
|
2526b7d166 | ||
|
|
ffb79d88ca | ||
|
|
4f32b9f180 | ||
|
|
1c3cc699d4 | ||
|
|
5a28642790 | ||
|
|
e8203596ab | ||
|
|
b88749a6c5 | ||
|
|
4a446b2779 | ||
|
|
a82feed8e3 | ||
|
|
da6e626f6f | ||
|
|
6950f94109 | ||
|
|
cf23916b8b | ||
|
|
d0c44273db | ||
|
|
80b29ba0d9 | ||
|
|
b6214c297f | ||
|
|
cc33213842 |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@@ -0,0 +1,15 @@
|
||||
.venv/
|
||||
.direnv/
|
||||
target/
|
||||
.git/
|
||||
.idea/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
dashboard/build/
|
||||
dist/
|
||||
*.pdb
|
||||
**/__pycache__
|
||||
**/.DS_Store
|
||||
.mlx_typings/
|
||||
42
.github/workflows/e2e.yml
vendored
Normal file
42
.github/workflows/e2e.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: e2e-tests
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
branches:
|
||||
- staging
|
||||
- main
|
||||
|
||||
jobs:
|
||||
e2e:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
|
||||
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
|
||||
/opt/microsoft /opt/az
|
||||
docker system prune -af
|
||||
df -h /
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build E2E image with cache
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: e2e/Dockerfile
|
||||
tags: exo-e2e:latest
|
||||
load: true
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Run E2E tests
|
||||
run: python3 e2e/run_all.py
|
||||
151
.mlx_typings/mlx_lm/models/step3p5.pyi
Normal file
151
.mlx_typings/mlx_lm/models/step3p5.pyi
Normal file
@@ -0,0 +1,151 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
vocab_size: int
|
||||
num_attention_heads: int
|
||||
num_attention_groups: int
|
||||
head_dim: int
|
||||
intermediate_size: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
max_position_embeddings: int
|
||||
sliding_window: int
|
||||
layer_types: Optional[List[str]]
|
||||
yarn_only_types: Optional[List[str]]
|
||||
partial_rotary_factors: Optional[List[float]]
|
||||
attention_other_setting: Optional[Dict[str, Any]]
|
||||
use_head_wise_attn_gate: bool
|
||||
moe_num_experts: int
|
||||
moe_top_k: int
|
||||
moe_intermediate_size: int
|
||||
share_expert_dim: int
|
||||
moe_layers_enum: Optional[str]
|
||||
moe_router_scaling_factor: float
|
||||
norm_expert_weight: bool
|
||||
swiglu_limits: Optional[List[float]]
|
||||
swiglu_limits_shared: Optional[List[float]]
|
||||
tie_word_embeddings: bool
|
||||
|
||||
class Step3p5MLP(nn.Module):
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
limit: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Step3p5MoEGate(nn.Module):
|
||||
top_k: int
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
norm_topk_prob: bool
|
||||
gate: nn.Linear
|
||||
router_bias: mx.array
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class Step3p5MoE(nn.Module):
|
||||
gate: Step3p5MoEGate
|
||||
switch_mlp: SwitchGLU
|
||||
share_expert: Step3p5MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Step3p5Attention(nn.Module):
|
||||
is_sliding: bool
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
q_norm: nn.Module
|
||||
k_norm: nn.Module
|
||||
use_head_wise_attn_gate: bool
|
||||
g_proj: nn.Linear
|
||||
rope: nn.Module
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Step3p5DecoderLayer(nn.Module):
|
||||
self_attn: Step3p5Attention
|
||||
is_sliding: bool
|
||||
is_moe_layer: bool
|
||||
mlp: Step3p5MLP | Step3p5MoE
|
||||
input_layernorm: nn.Module
|
||||
post_attention_layernorm: nn.Module
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Step3p5Model(nn.Module):
|
||||
args: ModelArgs
|
||||
vocab_size: int
|
||||
num_layers: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Step3p5DecoderLayer]
|
||||
norm: nn.Module
|
||||
_swa_idx: Optional[int]
|
||||
_full_idx: Optional[int]
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[List[Any]] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
args: ModelArgs
|
||||
model_type: str
|
||||
model: Step3p5Model
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Any]] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
|
||||
@property
|
||||
def layers(self) -> list[Step3p5DecoderLayer]: ...
|
||||
def make_cache(self) -> list[Any]: ...
|
||||
@property
|
||||
def cast_predicate(self) -> Any: ...
|
||||
@property
|
||||
def quant_predicate(self) -> Any: ...
|
||||
@@ -19,6 +19,11 @@ from urllib.parse import urlencode
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Backoff constants for cluster settling retry
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
@@ -388,6 +393,66 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
@@ -464,6 +529,12 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -487,11 +558,6 @@ def main() -> int:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = load_tokenizer_for_bench(full_model_id)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
@@ -503,54 +569,20 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
if not selected and args.settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + args.settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
|
||||
1
conftest.py
Normal file
1
conftest.py
Normal file
@@ -0,0 +1 @@
|
||||
collect_ignore = ["tests/start_distributed_test.py"]
|
||||
@@ -12,6 +12,8 @@
|
||||
ttftMs,
|
||||
tps,
|
||||
totalTokens,
|
||||
thinkingEnabled as thinkingEnabledStore,
|
||||
setConversationThinking,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import ChatAttachments from "./ChatAttachments.svelte";
|
||||
import ImageParamsPanel from "./ImageParamsPanel.svelte";
|
||||
@@ -25,6 +27,7 @@
|
||||
autofocus?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
modelTasks?: Record<string, string[]>;
|
||||
modelCapabilities?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -34,6 +37,7 @@
|
||||
autofocus = true,
|
||||
showModelSelector = false,
|
||||
modelTasks = {},
|
||||
modelCapabilities = {},
|
||||
}: Props = $props();
|
||||
|
||||
let message = $state("");
|
||||
@@ -41,6 +45,7 @@
|
||||
let fileInputRef: HTMLInputElement | undefined = $state();
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
let isDragOver = $state(false);
|
||||
const thinkingEnabled = $derived(thinkingEnabledStore());
|
||||
let loading = $derived(isLoading());
|
||||
const currentModel = $derived(selectedChatModel());
|
||||
const instanceData = $derived(instances());
|
||||
@@ -95,6 +100,12 @@
|
||||
);
|
||||
});
|
||||
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
currentModel !== null &&
|
||||
modelSupportsOnlyImageEditing(currentModel) &&
|
||||
@@ -282,7 +293,11 @@
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
sendMessage(content, files);
|
||||
sendMessage(
|
||||
content,
|
||||
files,
|
||||
modelSupportsThinking() ? thinkingEnabled : null,
|
||||
);
|
||||
}
|
||||
|
||||
// Refocus the textarea after sending
|
||||
@@ -520,6 +535,35 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Thinking toggle -->
|
||||
{#if modelSupportsThinking()}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => setConversationThinking(!thinkingEnabled)}
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded text-xs font-mono tracking-wide transition-all duration-200 flex-shrink-0 cursor-pointer border {thinkingEnabled
|
||||
? 'bg-exo-yellow/15 border-exo-yellow/40 text-exo-yellow'
|
||||
: 'bg-exo-medium-gray/30 border-exo-medium-gray/50 text-exo-light-gray/60 hover:text-exo-light-gray'}"
|
||||
title={thinkingEnabled
|
||||
? "Thinking enabled — click to disable"
|
||||
: "Thinking disabled — click to enable"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="1.5"
|
||||
>
|
||||
<path
|
||||
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
<span>{thinkingEnabled ? "THINK" : "NO THINK"}</span>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Performance stats -->
|
||||
{#if currentTtft !== null || currentTps !== null}
|
||||
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">
|
||||
|
||||
@@ -806,6 +806,7 @@
|
||||
isFavorite={favorites.has(group.id)}
|
||||
{selectedModelId}
|
||||
{canModelFit}
|
||||
{getModelFitStatus}
|
||||
onToggleExpand={() => toggleGroupExpanded(group.id)}
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
debugMode,
|
||||
nodeThunderboltBridge,
|
||||
nodeRdmaCtl,
|
||||
nodeIdentities,
|
||||
type NodeInfo,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
@@ -33,6 +34,7 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const rdmaCtlData = $derived(nodeRdmaCtl());
|
||||
const identitiesData = $derived(nodeIdentities());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -1177,6 +1179,22 @@
|
||||
.attr("font-size", debugFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(rdmaText);
|
||||
debugLabelY += debugLineHeight;
|
||||
}
|
||||
|
||||
const identity = identitiesData[nodeInfo.id];
|
||||
if (identity?.osVersion) {
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", debugLabelY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", "rgba(179,179,179,0.7)")
|
||||
.attr("font-size", debugFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(
|
||||
`macOS ${identity.osVersion}${identity.osBuildVersion ? ` (${identity.osBuildVersion})` : ""}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -296,6 +296,7 @@ export interface Conversation {
|
||||
modelId: string | null;
|
||||
sharding: string | null;
|
||||
instanceType: string | null;
|
||||
enableThinking: boolean | null;
|
||||
}
|
||||
|
||||
const STORAGE_KEY = "exo-conversations";
|
||||
@@ -605,6 +606,7 @@ class AppStore {
|
||||
modelId: conversation.modelId ?? null,
|
||||
sharding: conversation.sharding ?? null,
|
||||
instanceType: conversation.instanceType ?? null,
|
||||
enableThinking: conversation.enableThinking ?? null,
|
||||
}));
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -794,6 +796,7 @@ class AppStore {
|
||||
modelId: derivedModelId,
|
||||
sharding: derivedSharding,
|
||||
instanceType: derivedInstanceType,
|
||||
enableThinking: null,
|
||||
};
|
||||
|
||||
this.conversations.unshift(conversation);
|
||||
@@ -819,6 +822,7 @@ class AppStore {
|
||||
this.hasStartedChat = true;
|
||||
this.isTopologyMinimized = true;
|
||||
this.isSidebarOpen = true; // Auto-open sidebar when chatting
|
||||
this.thinkingEnabled = conversation.enableThinking ?? true;
|
||||
this.refreshConversationModelFromInstances();
|
||||
|
||||
return true;
|
||||
@@ -1932,6 +1936,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether thinking is enabled for the current conversation
|
||||
*/
|
||||
thinkingEnabled = $state(true);
|
||||
|
||||
/**
|
||||
* Selected model for chat (can be set by the UI)
|
||||
*/
|
||||
@@ -2110,6 +2119,7 @@ class AppStore {
|
||||
textContent?: string;
|
||||
preview?: string;
|
||||
}[],
|
||||
enableThinking?: boolean | null,
|
||||
): Promise<void> {
|
||||
if ((!content.trim() && (!files || files.length === 0)) || this.isLoading)
|
||||
return;
|
||||
@@ -2257,6 +2267,9 @@ class AppStore {
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
...(enableThinking != null && {
|
||||
enable_thinking: enableThinking,
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -2915,6 +2928,18 @@ class AppStore {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the thinking preference for the active conversation
|
||||
*/
|
||||
setConversationThinking(enabled: boolean) {
|
||||
this.thinkingEnabled = enabled;
|
||||
const conv = this.getActiveConversation();
|
||||
if (conv) {
|
||||
conv.enableThinking = enabled;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a download on a specific node
|
||||
*/
|
||||
@@ -3028,6 +3053,7 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
|
||||
export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const thinkingEnabled = () => appStore.thinkingEnabled;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
@@ -3043,7 +3069,8 @@ export const sendMessage = (
|
||||
textContent?: string;
|
||||
preview?: string;
|
||||
}[],
|
||||
) => appStore.sendMessage(content, files);
|
||||
enableThinking?: boolean | null,
|
||||
) => appStore.sendMessage(content, files, enableThinking);
|
||||
export const generateImage = (prompt: string, modelId?: string) =>
|
||||
appStore.generateImage(prompt, modelId);
|
||||
export const editImage = (
|
||||
@@ -3086,6 +3113,8 @@ export const deleteAllConversations = () => appStore.deleteAllConversations();
|
||||
export const renameConversation = (id: string, name: string) =>
|
||||
appStore.renameConversation(id, name);
|
||||
export const getActiveConversation = () => appStore.getActiveConversation();
|
||||
export const setConversationThinking = (enabled: boolean) =>
|
||||
appStore.setConversationThinking(enabled);
|
||||
|
||||
// Sidebar actions
|
||||
export const isSidebarOpen = () => appStore.isSidebarOpen;
|
||||
|
||||
@@ -190,6 +190,19 @@
|
||||
return tasks;
|
||||
});
|
||||
|
||||
const modelCapabilities = $derived(() => {
|
||||
const caps: Record<string, string[]> = {};
|
||||
for (const model of models) {
|
||||
if (model.capabilities && model.capabilities.length > 0) {
|
||||
caps[model.id] = model.capabilities;
|
||||
if (model.hugging_face_id) {
|
||||
caps[model.hugging_face_id] = model.capabilities;
|
||||
}
|
||||
}
|
||||
}
|
||||
return caps;
|
||||
});
|
||||
|
||||
// Helper to check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const model = models.find(
|
||||
@@ -2270,6 +2283,7 @@
|
||||
showHelperText={false}
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
modelCapabilities={modelCapabilities()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -3049,6 +3063,7 @@
|
||||
placeholder="Ask anything"
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
modelCapabilities={modelCapabilities()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
58
e2e/Dockerfile
Normal file
58
e2e/Dockerfile
Normal file
@@ -0,0 +1,58 @@
|
||||
# Stage 1: Build the dashboard
|
||||
FROM node:22-slim AS dashboard
|
||||
WORKDIR /app/dashboard
|
||||
COPY dashboard/package.json dashboard/package-lock.json ./
|
||||
RUN npm ci
|
||||
COPY dashboard/ .
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Build and run exo
|
||||
FROM python:3.13-slim
|
||||
|
||||
# Install system dependencies
|
||||
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
libblas-dev \
|
||||
liblapack-dev \
|
||||
liblapacke-dev \
|
||||
curl \
|
||||
protobuf-compiler \
|
||||
iptables \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust nightly
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
|
||||
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
|
||||
# Must be done BEFORE uv sync so any source builds also get the fix
|
||||
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
|
||||
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
|
||||
chmod +x /usr/bin/g++
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files first for better layer caching
|
||||
COPY pyproject.toml Cargo.toml uv.lock README.md ./
|
||||
COPY rust/ ./rust/
|
||||
COPY bench/pyproject.toml ./bench/pyproject.toml
|
||||
|
||||
# Copy source and resources
|
||||
COPY src/ ./src/
|
||||
COPY resources/ ./resources/
|
||||
|
||||
# Copy built dashboard from stage 1
|
||||
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
|
||||
|
||||
# Install Python deps and build Rust bindings, then clean up build artifacts
|
||||
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
|
||||
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
|
||||
|
||||
CMD [".venv/bin/exo", "-v"]
|
||||
195
e2e/conftest.py
Normal file
195
e2e/conftest.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Shared E2E test infrastructure for exo cluster tests."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
E2E_DIR = Path(__file__).parent.resolve()
|
||||
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
|
||||
|
||||
|
||||
class Cluster:
|
||||
"""Async wrapper around a docker compose exo cluster."""
|
||||
|
||||
def __init__(self, name: str, overrides: list[str] | None = None):
|
||||
self.name = name
|
||||
self.project = f"e2e-{name}"
|
||||
compose_files = [str(E2E_DIR / "docker-compose.yml")]
|
||||
for path in overrides or []:
|
||||
compose_files.append(str(E2E_DIR / path))
|
||||
self._compose_base = [
|
||||
"docker",
|
||||
"compose",
|
||||
"-p",
|
||||
self.project,
|
||||
*[arg for f in compose_files for arg in ("-f", f)],
|
||||
]
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
await self.stop()
|
||||
|
||||
async def _run(self, *args: str, check: bool = True) -> str:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*self._compose_base,
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
output = stdout.decode()
|
||||
if check and proc.returncode != 0:
|
||||
print(output, file=sys.stderr)
|
||||
raise RuntimeError(
|
||||
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
|
||||
)
|
||||
return output
|
||||
|
||||
async def build(self):
|
||||
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"docker",
|
||||
"image",
|
||||
"inspect",
|
||||
"exo-e2e:latest",
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
)
|
||||
await proc.wait()
|
||||
if proc.returncode == 0:
|
||||
print(" Using pre-built image (exo-e2e:latest)")
|
||||
return
|
||||
print(" Building images...")
|
||||
await self._run("build", "--quiet")
|
||||
|
||||
async def start(self):
|
||||
print(" Starting cluster...")
|
||||
await self._run("up", "-d")
|
||||
|
||||
async def stop(self):
|
||||
print(" Cleaning up...")
|
||||
await self._run("down", "--timeout", "5", check=False)
|
||||
|
||||
async def logs(self) -> str:
|
||||
return await self._run("logs", check=False)
|
||||
|
||||
async def exec(
|
||||
self, service: str, *cmd: str, check: bool = True
|
||||
) -> tuple[int, str]:
|
||||
"""Run a command inside a running container. Returns (returncode, output)."""
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*self._compose_base,
|
||||
"exec",
|
||||
"-T",
|
||||
service,
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
output = stdout.decode()
|
||||
if check and proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
|
||||
)
|
||||
return proc.returncode, output
|
||||
|
||||
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
|
||||
"""Poll check_fn every 2s until it returns True or timeout expires."""
|
||||
print(f" Waiting for {description}...")
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
if await check_fn():
|
||||
print(f" {description}")
|
||||
return
|
||||
await asyncio.sleep(2)
|
||||
output = await self.logs()
|
||||
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
|
||||
raise TimeoutError(f"Timed out waiting for {description}")
|
||||
|
||||
async def assert_healthy(self):
|
||||
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
|
||||
|
||||
async def both_nodes_started():
|
||||
log = await self.logs()
|
||||
return log.count("Starting node") >= 2
|
||||
|
||||
async def nodes_discovered():
|
||||
log = await self.logs()
|
||||
return log.count("ConnectionMessageType.Connected") >= 2
|
||||
|
||||
async def master_elected():
|
||||
log = await self.logs()
|
||||
return "demoting self" in log
|
||||
|
||||
async def api_responding():
|
||||
try:
|
||||
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
|
||||
return resp.status == 200
|
||||
except (URLError, OSError):
|
||||
return False
|
||||
|
||||
await self.wait_for("Both nodes started", both_nodes_started)
|
||||
await self.wait_for("Nodes discovered each other", nodes_discovered)
|
||||
await self.wait_for("Master election resolved", master_elected)
|
||||
await self.wait_for("API responding", api_responding)
|
||||
|
||||
async def _api(
|
||||
self, method: str, path: str, body: dict | None = None, timeout: int = 30
|
||||
) -> dict:
|
||||
"""Make an API request to the cluster. Returns parsed JSON."""
|
||||
url = f"http://localhost:52415{path}"
|
||||
data = json.dumps(body).encode() if body else None
|
||||
req = Request(
|
||||
url, data=data, headers={"Content-Type": "application/json"}, method=method
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
resp_bytes = await loop.run_in_executor(
|
||||
None, lambda: urlopen(req, timeout=timeout).read()
|
||||
)
|
||||
return json.loads(resp_bytes)
|
||||
|
||||
async def place_model(self, model: str, timeout: int = 600):
|
||||
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
|
||||
await self._api("POST", "/place_instance", {"model_id": model})
|
||||
|
||||
async def model_ready():
|
||||
try:
|
||||
resp = await self._api("GET", "/v1/models")
|
||||
return any(m.get("id") == model for m in resp.get("data", []))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
|
||||
|
||||
async def chat(
|
||||
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
|
||||
) -> dict:
|
||||
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
|
||||
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
last_error = None
|
||||
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
try:
|
||||
req = Request(
|
||||
"http://localhost:52415/v1/chat/completions",
|
||||
data=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
resp_bytes = await loop.run_in_executor(
|
||||
None, lambda r=req: urlopen(r, timeout=300).read()
|
||||
)
|
||||
return json.loads(resp_bytes)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
await asyncio.sleep(5)
|
||||
|
||||
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")
|
||||
20
e2e/docker-compose.yml
Normal file
20
e2e/docker-compose.yml
Normal file
@@ -0,0 +1,20 @@
|
||||
services:
|
||||
exo-node-1:
|
||||
image: exo-e2e:latest
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
ports:
|
||||
- "52415:52415"
|
||||
|
||||
exo-node-2:
|
||||
image: exo-e2e:latest
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
77
e2e/run_all.py
Normal file
77
e2e/run_all.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Discovers and runs all E2E tests in e2e/test_*.py.
|
||||
|
||||
Tests with '# slow' on the first line of their docstring are skipped
|
||||
unless --slow is passed or E2E_SLOW=1 is set.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
E2E_DIR = Path(__file__).parent.resolve()
|
||||
|
||||
|
||||
def is_slow(test_file: Path) -> bool:
|
||||
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
|
||||
with open(test_file) as f:
|
||||
for line in f:
|
||||
if line.strip().startswith("#"):
|
||||
continue
|
||||
if line.strip().startswith('"""') or line.strip().startswith("'''"):
|
||||
# Read into the docstring
|
||||
for doc_line in f:
|
||||
if "slow" in doc_line.lower() and doc_line.strip().startswith(
|
||||
"slow"
|
||||
):
|
||||
return True
|
||||
if '"""' in doc_line or "'''" in doc_line:
|
||||
break
|
||||
break
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
|
||||
if "--update-snapshots" in sys.argv:
|
||||
os.environ["UPDATE_SNAPSHOTS"] = "1"
|
||||
test_files = sorted(E2E_DIR.glob("test_*.py"))
|
||||
if not test_files:
|
||||
print("No test files found")
|
||||
sys.exit(1)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
failures = []
|
||||
|
||||
for test_file in test_files:
|
||||
name = test_file.stem
|
||||
if is_slow(test_file) and not run_slow:
|
||||
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
print(f"=== {name} ===")
|
||||
result = subprocess.run([sys.executable, str(test_file)])
|
||||
if result.returncode == 0:
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
failures.append(name)
|
||||
print()
|
||||
|
||||
total = passed + failed + skipped
|
||||
print("================================")
|
||||
print(
|
||||
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
|
||||
)
|
||||
|
||||
if failed:
|
||||
print(f"Failed: {' '.join(failures)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
69
e2e/snapshot.py
Normal file
69
e2e/snapshot.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Snapshot testing infrastructure for E2E tests.
|
||||
|
||||
Provides deterministic regression testing by comparing inference output
|
||||
against saved snapshots. On first run, snapshots are created automatically.
|
||||
Set UPDATE_SNAPSHOTS=1 to regenerate snapshots when output intentionally changes.
|
||||
|
||||
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
|
||||
since floating-point results differ between CPU architectures.
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
ARCH = platform.machine()
|
||||
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
|
||||
|
||||
|
||||
def assert_snapshot(
|
||||
name: str,
|
||||
content: str,
|
||||
metadata: dict,
|
||||
) -> None:
|
||||
"""Compare content against a saved snapshot, or create one if missing.
|
||||
|
||||
Args:
|
||||
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
|
||||
content: The actual inference output to compare.
|
||||
metadata: Additional context stored alongside content (model, seed, etc.).
|
||||
Not used for comparison -- purely documentary.
|
||||
|
||||
Raises:
|
||||
AssertionError: If content doesn't match the saved snapshot.
|
||||
|
||||
Environment:
|
||||
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
|
||||
"""
|
||||
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
|
||||
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
|
||||
|
||||
if snapshot_file.exists() and not update:
|
||||
snapshot = json.loads(snapshot_file.read_text())
|
||||
expected = snapshot["content"]
|
||||
if content != expected:
|
||||
diff = "\n".join(
|
||||
difflib.unified_diff(
|
||||
expected.splitlines(),
|
||||
content.splitlines(),
|
||||
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
|
||||
tofile="actual",
|
||||
lineterm="",
|
||||
)
|
||||
)
|
||||
raise AssertionError(
|
||||
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
|
||||
f"{diff}\n\n"
|
||||
f"Expected: {expected!r}\n"
|
||||
f"Actual: {content!r}\n\n"
|
||||
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py"
|
||||
)
|
||||
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")
|
||||
else:
|
||||
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_data = {**metadata, "arch": ARCH, "content": content}
|
||||
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
|
||||
action = "Updated" if update else "Created"
|
||||
print(f" {action} snapshot: {ARCH}/{snapshot_file.name}")
|
||||
22
e2e/test_cluster_formation.py
Normal file
22
e2e/test_cluster_formation.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Test: Basic cluster formation.
|
||||
|
||||
Verifies two nodes discover each other, elect a master, and the API responds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("cluster_formation") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
print("PASSED: cluster_formation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
60
e2e/test_inference_snapshot.py
Normal file
60
e2e/test_inference_snapshot.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Test: Deterministic inference output (snapshot test).
|
||||
|
||||
Sends a chat completion request with a fixed seed,
|
||||
then verifies the output matches a known-good snapshot. This ensures
|
||||
inference produces consistent results across runs.
|
||||
|
||||
Uses MLX CPU backend in Docker on x86 Linux.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = "What is 2+2? Reply with just the number."
|
||||
MAX_TOKENS = 32
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("inference_snapshot") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED})...")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name="inference_snapshot",
|
||||
content=content,
|
||||
metadata={
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print("PASSED: inference_snapshot")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
47
e2e/test_no_internet.py
Normal file
47
e2e/test_no_internet.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Test: Cluster works without internet access.
|
||||
|
||||
Verifies exo functions correctly when containers can talk to each other
|
||||
but cannot reach the internet. Uses iptables to block all outbound traffic
|
||||
except private subnets and multicast (for mDNS discovery).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster(
|
||||
"no_internet",
|
||||
overrides=["tests/no_internet/docker-compose.override.yml"],
|
||||
) as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
# Verify internet is actually blocked from inside the containers
|
||||
for node in ["exo-node-1", "exo-node-2"]:
|
||||
rc, _ = await cluster.exec(
|
||||
node,
|
||||
"curl",
|
||||
"-sf",
|
||||
"--max-time",
|
||||
"3",
|
||||
"https://huggingface.co",
|
||||
check=False,
|
||||
)
|
||||
assert rc != 0, f"{node} should not be able to reach the internet"
|
||||
print(f" {node}: internet correctly blocked")
|
||||
|
||||
# Verify exo detected no internet connectivity
|
||||
log = await cluster.logs()
|
||||
assert "Internet connectivity: False" in log, "exo should detect no internet"
|
||||
print(" exo correctly detected no internet connectivity")
|
||||
|
||||
print("PASSED: no_internet")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
58
e2e/test_snapshot_code_gen.py
Normal file
58
e2e/test_snapshot_code_gen.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Test: Code generation snapshot.
|
||||
|
||||
Verifies deterministic output for a code generation prompt.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = (
|
||||
"Write a Python function to reverse a string. Only output the code, no explanation."
|
||||
)
|
||||
MAX_TOKENS = 64
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("snapshot_code_gen") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED})...")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name="snapshot_code_gen",
|
||||
content=content,
|
||||
metadata={
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print("PASSED: snapshot_code_gen")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
63
e2e/test_snapshot_edge.py
Normal file
63
e2e/test_snapshot_edge.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Test: Edge case snapshots.
|
||||
|
||||
Verifies deterministic output for edge-case prompts: single word input,
|
||||
special characters, and unicode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
MAX_TOKENS = 32
|
||||
|
||||
CASES = [
|
||||
("edge_single_word", "Hi"),
|
||||
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
|
||||
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("snapshot_edge") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
for snapshot_name, prompt in CASES:
|
||||
print(f" [{snapshot_name}] Sending: {prompt!r}")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" [{snapshot_name}] Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name=snapshot_name,
|
||||
content=content,
|
||||
metadata={
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"prompt": prompt,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print("PASSED: snapshot_edge")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
56
e2e/test_snapshot_long_output.py
Normal file
56
e2e/test_snapshot_long_output.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Test: Longer output snapshot.
|
||||
|
||||
Verifies deterministic output with a higher max_tokens (128).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = "Explain how a binary search algorithm works."
|
||||
MAX_TOKENS = 128
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("snapshot_long_output") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name="snapshot_long_output",
|
||||
content=content,
|
||||
metadata={
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print("PASSED: snapshot_long_output")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
72
e2e/test_snapshot_multi_model.py
Normal file
72
e2e/test_snapshot_multi_model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Test: Multi-model snapshot tests.
|
||||
slow
|
||||
|
||||
Verifies deterministic output across different model architectures to catch
|
||||
model-specific regressions. Each model uses its own snapshot file.
|
||||
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
SEED = 42
|
||||
PROMPT = "What is the capital of France?"
|
||||
MAX_TOKENS = 32
|
||||
|
||||
MODELS = [
|
||||
"mlx-community/SmolLM2-135M-Instruct",
|
||||
"mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||
"mlx-community/gemma-2-2b-it-4bit",
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("snapshot_multi_model") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
for model in MODELS:
|
||||
short_name = (
|
||||
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
|
||||
)
|
||||
snapshot_name = f"snapshot_multi_{short_name}"
|
||||
|
||||
print(f" Launching model {model}...")
|
||||
await cluster.place_model(model)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED})...")
|
||||
resp = await cluster.chat(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" [{short_name}] Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name=snapshot_name,
|
||||
content=content,
|
||||
metadata={
|
||||
"model": model,
|
||||
"seed": SEED,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print(f" [{short_name}] PASSED")
|
||||
|
||||
print("PASSED: snapshot_multi_model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
56
e2e/test_snapshot_reasoning.py
Normal file
56
e2e/test_snapshot_reasoning.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Test: Reasoning/math snapshot.
|
||||
|
||||
Verifies deterministic output for a simple reasoning prompt.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from snapshot import assert_snapshot
|
||||
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
|
||||
MAX_TOKENS = 64
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("snapshot_reasoning") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED})...")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" Response: {content!r}")
|
||||
|
||||
assert_snapshot(
|
||||
name="snapshot_reasoning",
|
||||
content=content,
|
||||
metadata={
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
},
|
||||
)
|
||||
|
||||
print("PASSED: snapshot_reasoning")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
32
e2e/tests/no_internet/docker-compose.override.yml
Normal file
32
e2e/tests/no_internet/docker-compose.override.yml
Normal file
@@ -0,0 +1,32 @@
|
||||
# Block all outbound internet traffic using iptables while preserving:
|
||||
# - Multicast (224.0.0.0/4) for mDNS peer discovery
|
||||
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
|
||||
# - Loopback (127/8)
|
||||
# Requires NET_ADMIN capability for iptables.
|
||||
services:
|
||||
exo-node-1:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
exo-node-2:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/SmolLM2-135M-Instruct"
|
||||
n_layers = 30
|
||||
hidden_size = 576
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "llama"
|
||||
quantization = "bf16"
|
||||
base_model = "SmolLM2 135M"
|
||||
capabilities = ["text"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 269060381
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-4bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-6bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/gemma-2-2b-it-4bit"
|
||||
n_layers = 26
|
||||
hidden_size = 2304
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
family = "gemma2"
|
||||
quantization = "4bit"
|
||||
base_model = "Gemma 2 2B"
|
||||
capabilities = ["text"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1492755242
|
||||
@@ -79,6 +79,7 @@ def chat_request_to_text_generation(
|
||||
seed=request.seed,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
enable_thinking=request.enable_thinking,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
|
||||
@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
|
||||
["MiniMaxM2ForCausalLM"],
|
||||
["LlamaForCausalLM"],
|
||||
["GptOssForCausalLM"],
|
||||
["Step3p5ForCausalLM"],
|
||||
]
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -199,6 +199,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
enable_thinking: bool | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
|
||||
@@ -40,5 +40,6 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
|
||||
stop: str | list[str] | None = None
|
||||
seed: int | None = None
|
||||
chat_template_messages: list[dict[str, Any]] | None = None
|
||||
enable_thinking: bool | None = None
|
||||
logprobs: bool = False
|
||||
top_logprobs: int | None = None
|
||||
|
||||
@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.step3p5 import Model as Step35Model
|
||||
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
|
||||
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, Step35InnerModel):
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
sliding_layers = [
|
||||
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
|
||||
]
|
||||
full_layers = [
|
||||
i
|
||||
for i, layer in enumerate(layers)
|
||||
if not getattr(layer, "is_sliding", True)
|
||||
]
|
||||
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
|
||||
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Step35Model):
|
||||
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Step35Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_heads //= self.N
|
||||
layer.self_attn.num_kv_heads //= self.N
|
||||
|
||||
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
|
||||
layer.self_attn.g_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.g_proj
|
||||
)
|
||||
|
||||
if isinstance(layer.mlp, Step35MLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
else:
|
||||
layer.mlp.sharding_group = self.group
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
@@ -462,11 +462,19 @@ def apply_chat_template(
|
||||
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
|
||||
formatted_messages = formatted_messages[:-1]
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if task_params.enable_thinking is not None:
|
||||
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
|
||||
# Jinja ignores unknown variables, so passing both is safe.
|
||||
extra_kwargs["enable_thinking"] = task_params.enable_thinking
|
||||
extra_kwargs["thinking"] = task_params.enable_thinking
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=task_params.tools,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if partial_assistant_content:
|
||||
|
||||
Reference in New Issue
Block a user