Compare commits


16 Commits

Author SHA1 Message Date
Evan
103cbdee58 shfmt ignore 2026-02-06 17:08:26 +00:00
Evan
dbcc829625 sdk version? 2026-02-06 16:58:18 +00:00
Evan
30b384e2e6 vendor the apple sdk nix package 2026-02-06 16:40:10 +00:00
Evan
6675feed71 enable MLX_BUILD_CPU in nix built mlx 2026-02-06 16:14:04 +00:00
Evan Quiney
9b5cae3db6 auto bench (#1405)
runs exo_bench remotely with some nice git QoL

## usage
run tests/auto_bench.sh host1 [host2]

exo bench will be run on those hosts against all currently downloaded
models, with output saved to bench/commit_hash/*.json
2026-02-06 15:35:46 +00:00
Jake Hillion
cf7201f91e pyproject: set minimum uv version
The uv.lock is churning constantly as different uv versions bounce it
between revisions. This is made worse by GitHub automatically hiding
uv.lock changes, which makes it hard to notice when this goes wrong.

Set a minimum version for `uv` in pyproject.toml to fix this. I tried
quite a few versions (not all) and found that 0.8.6 sets the lockfile
revision to 3, which I believe is the latest. 0.8.6 is from August 2025,
so it has been around for a while.
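For reference, this is roughly what the constraint looks like in
pyproject.toml, assuming uv's `required-version` setting:

```toml
[tool.uv]
# Refuse to run under an older uv, so uv.lock stays at revision 3.
required-version = ">=0.8.6"
```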

Test plan:

```
jake@maverick:/data/users/jake/repos/exo/ > git checkout main uv.lock
jake@maverick:/data/users/jake/repos/exo/ > nix shell github:nixos/nixpkgs/3dce7f4a77812afd69efcbfe15e5223f98c5c69e#uv --command sh -c 'uv add pip --frozen && uv lock && uv remove pip --frozen && uv lock && uv --version'

Resolved 140 packages in 147ms
Added pip v26.0.1
Resolved 139 packages in 48ms
Removed pip v26.0.1
uv 0.8.6
```
2026-02-06 15:28:10 +00:00
rltakashige
b315035ae0 Add minimax and fix qwen sharding strategies (#1318)
## Motivation

MiniMax tensor sharding does not produce outputs equivalent to running
on a single node, because RMSNorm weights cannot be split without
affecting the output.
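To see why (a toy sketch, not exo's code): RMSNorm divides by the
root-mean-square over the full hidden dimension, so normalizing each
shard with its own local RMS gives a different result than normalizing
the whole vector:

```python
import numpy as np

def rms_norm(x, w):
    # Normalize by the RMS over the whole vector, then scale by w.
    return x / np.sqrt(np.mean(x**2) + 1e-6) * w

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.ones(4)

full = rms_norm(x, w)
# Splitting x and w in half and normalizing each shard independently
# uses per-shard denominators, so the concatenated result differs.
sharded = np.concatenate([rms_norm(x[:2], w[:2]), rms_norm(x[2:], w[2:])])
print(np.allclose(full, sharded))  # False
```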

Qwen3Next sharding was broken, and something with Qwen3MoE was likely
changed upstream, as several variables no longer exist.

This also ballooned into fixing prefix caching for non-standard models,
as Qwen3Next was behaving weirdly.

## Test Plan

### Manual Testing
Worked for an 8-hour eval at the same performance and with a more
similar completion/reasoning token distribution.

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
2026-02-06 13:26:59 +00:00
rltakashige
c8dbbee27b skip tensor ring on bench (#1403)
2026-02-06 13:06:59 +00:00
rltakashige
f0107e9670 Fix offline no cache (#1402)
## Motivation

In offline mode, exo complains if there is no caches directory, even if
the files are there.

## Changes

Check the safetensors index and the directory structure to rebuild the
caches directory.
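As a rough sketch of the idea (hypothetical helper, not exo's actual
code): read the safetensors index to learn which shard files the model
needs, then treat the model as present if they all exist on disk.

```python
import json
from pathlib import Path

def model_files_present(model_dir: Path) -> bool:
    # Check a downloaded model directory against its safetensors index.
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        # Single-file models ship without an index; check the file directly.
        return (model_dir / "model.safetensors").exists()
    index = json.loads(index_path.read_text())
    # weight_map maps tensor names to the shard files that hold them.
    shards = set(index["weight_map"].values())
    return all((model_dir / shard).exists() for shard in shards)
```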

## Test Plan

### Manual Testing
<img width="2338" height="1102" alt="image"
src="https://github.com/user-attachments/assets/ad769911-399b-4fca-ac80-aeaa046af06b"
/>
<img width="656" height="1668" alt="image"
src="https://github.com/user-attachments/assets/6080986c-3904-4600-a340-8c70f1b33266"
/>
2026-02-06 12:57:01 +00:00
Hunter Bown
9f502793c1 fix: retry downloads on transient errors instead of breaking (#1398)
## Motivation

`download_file_with_retry()` has a `break` in the generic exception
handler that exits the retry loop after the first transient failure.
This means network timeouts, connection resets, and server errors all
cause an immediate download failure — the two remaining retry attempts
never run.

## Changes

**download_utils.py**: Replaced `break` with logging and exponential
backoff in the generic exception handler, matching the existing
rate-limit handler behavior.

Before:
```python
except Exception as e:
    on_connection_lost()
    if attempt == n_attempts - 1:
        raise e
    break  # exits loop immediately
```

After:
```python
except Exception as e:
    on_connection_lost()
    if attempt == n_attempts - 1:
        raise e
    logger.error(f"Download error on attempt {attempt + 1}/{n_attempts} ...")
    logger.error(traceback.format_exc())
    await asyncio.sleep(2.0**attempt)
```

## Why It Works

The `break` statement was bypassing the retry mechanism entirely.
Replacing it with the same log-and-backoff pattern used by the
`HuggingFaceRateLimitError` handler means all 3 attempts are actually
used before giving up. The exponential backoff (1s, 2s) gives transient
issues time to resolve between attempts.
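Put together, the loop now behaves roughly like this (a sketch with
assumed names, not the verbatim function):

```python
import asyncio
import traceback

async def download_with_retry(do_download, n_attempts: int = 3):
    for attempt in range(n_attempts):
        try:
            return await do_download()
        except Exception:
            if attempt == n_attempts - 1:
                raise
            # Log and back off (1s, then 2s) instead of breaking out.
            traceback.print_exc()
            await asyncio.sleep(2.0**attempt)
```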

## Test Plan

### Manual Testing
- Downloads that hit transient network errors now retry instead of
failing immediately

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `uv run pytest src/exo/download/tests/ -v` — 11 tests pass

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-06 11:51:54 +00:00
Evan Quiney
c8371349d5 add scripts (#1401)
allow running exo-bench and the headless runner from nix
2026-02-06 11:06:40 +00:00
Evan Quiney
6b907398a4 cancel downloads for deleted instances (#1393)
after deleting an instance, if a given (node_id, model_id) pair doesn't exist in the leftover instances, cancel the download of model_id on node_id.
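In sketch form (the names are illustrative, not exo's actual API):

```python
def downloads_to_cancel(deleted_instance, remaining_instances):
    # (node_id, model_id) pairs still required by surviving instances.
    still_needed = {
        (shard.node_id, shard.model_id)
        for inst in remaining_instances
        for shard in inst.shards
    }
    # Cancel every pair the deleted instance used that nothing else needs.
    return [
        (shard.node_id, shard.model_id)
        for shard in deleted_instance.shards
        if (shard.node_id, shard.model_id) not in still_needed
    ]
```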
2026-02-05 18:16:43 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running, leading to bad shutdown states

## changes
- added `try: except` blocks around most task groups (see the sketch
after this list)
- made the runner shutdown code synchronous
- abandon the MpReceiver's recv_async thread on cancellation
- this only occurs during runner shutdown; the queue closing from the
other end should terminate the mp.Queue, cleaning up the thread in its
own time. i could try other methods if this is not sufficient.
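A minimal sketch of the pattern (illustrative names, Python 3.11+ task
groups):

```python
import asyncio

async def run_node(serve, watch_downloads, shutdown):
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(serve())
            tg.create_task(watch_downloads())
    finally:
        # Runs on normal exit, errors, and cancellation (ctrl-c), so
        # cleanup can no longer be skipped. Synchronous on purpose: it
        # can't be cancelled partway through.
        shutdown()
```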

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
Alex Cheema
5c2f29f3f2 feat: show download availability in model picker (#1377)
## Motivation

Users browsing models in the picker need to know which models are
already downloaded and ready to run on their cluster, without having to
check the downloads page separately.

## Changes

- **ModelPickerModal.svelte**: Computes per-model download availability
by checking which nodes have `DownloadCompleted` entries and summing
their total RAM against the model's storage size. Passes availability
data to `ModelPickerGroup`. Enhances the info modal with a "Downloaded
on:" section showing node friendly names with green badges.
- **ModelPickerGroup.svelte**: Accepts new `downloadStatus` prop. Shows
a green checkmark-in-circle icon next to models that are downloaded on
sufficient nodes. Tooltip shows which nodes have the model.
- **+page.svelte**: Passes `downloadsData` and `topologyNodes` to
`ModelPickerModal`.

## Why It Works

The download state from `/state` already tracks per-node completed
downloads. The shared `getNodesWithModelDownloaded()` utility (from PR
#1375) finds nodes with `DownloadCompleted` entries for each model.
Total RAM is summed from the topology node data (using `ram_total`, not
`ram_available`) and compared to the model's `storage_size_megabytes` to
determine if there's enough aggregate memory. This is intentionally a
simple heuristic — not a full placement preview.

## Test Plan

### Manual Testing
- Open the model picker modal
- Verify downloaded models show a green checkmark icon
- Verify the checkmark appears dimmer for models downloaded on nodes
with insufficient total RAM
- Click the (i) info button on a downloaded model
- Verify "Downloaded on:" section appears with correct node names
- Verify models with no downloads show no indicator

### Automated Testing
- Dashboard builds successfully (`npm run build`)
- No new Python changes requiring type checking

> **Note:** This is a chained PR. Base branch is
`alexcheema/topology-download-indicators` (#1375).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 14:32:53 +00:00
Alex Cheema
ffe6396c91 Add Qwen3-Coder-Next model cards (#1367)
## Motivation

Qwen3-Coder-Next just dropped on mlx-community in several quantizations.
It's an 80B MoE model (Qwen3NextForCausalLM) which we already have
tensor parallelism support for via QwenShardingStrategy — just needs
model cards.

## Changes

Added model cards for all 5 available quantizations:
- `mlx-community/Qwen3-Coder-Next-4bit` (~46GB)
- `mlx-community/Qwen3-Coder-Next-5bit` (~58GB)
- `mlx-community/Qwen3-Coder-Next-6bit` (~69GB)
- `mlx-community/Qwen3-Coder-Next-8bit` (~89GB)
- `mlx-community/Qwen3-Coder-Next-bf16` (~158GB)

All with `supports_tensor = true` since the architecture is already
supported.

## Why It Works

`Qwen3NextForCausalLM` is already handled by QwenShardingStrategy in
auto_parallel.py and is in the supports_tensor allowlist in
model_cards.py. No code changes needed — just the TOML card files.
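For flavor, a hypothetical sketch of one card (only `model_id` and
`supports_tensor` are suggested by this description and the dashboard
code; the real schema is whatever model_cards.py expects):

```toml
# Hypothetical model card sketch; fields beyond supports_tensor are illustrative.
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
supports_tensor = true
```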

## Test Plan

### Manual Testing
n/a (model card addition only)

### Automated Testing
- `basedpyright` — 0 errors
- `ruff check` — passes
- `nix fmt` — no changes
- `pytest` — 173 passed, 1 skipped


🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 13:37:18 +00:00
81 changed files with 4892 additions and 3069 deletions

.gitignore (vendored)

@@ -32,3 +32,6 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp
# bench files
bench/**/*.json

View File

@@ -1139,7 +1139,7 @@ class array:
) -> array:
"""See :func:`flatten`."""
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`reshape` but the shape can be passed either as a
:obj:`tuple` or as separate arguments.
@@ -1222,7 +1222,7 @@ class array:
) -> array:
"""See :func:`swapaxes`."""
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`transpose` but the axes can be passed either as
a tuple or as separate arguments.

View File

@@ -30,6 +30,9 @@ class Conv1d(Module):
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""
weight: mx.array
groups: int
def __init__(
self,
in_channels: int,

View File

@@ -11,7 +11,10 @@ import mlx.core as mx
class Cache(Protocol):
keys: mx.array
values: mx.array
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
offset: int
def update_and_fetch(
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter
@@ -87,6 +90,7 @@ def create_attention_mask(
class _BaseCache(Cache):
keys: mx.array
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.mla import MultiLinear
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
# kv_b_proj: nn.Linear
embed_q: MultiLinear
unembed_out: MultiLinear
o_proj: nn.Linear
rope: Any

View File

@@ -0,0 +1,114 @@
"""Type stubs for mlx_lm.models.qwen3_next"""
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .switch_layers import SwitchGLU
class Qwen3NextMLP(nn.Module):
gate_proj: nn.Linear
down_proj: nn.Linear
up_proj: nn.Linear
def __init__(self, dim: int, hidden_dim: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Qwen3NextGatedDeltaNet(nn.Module):
hidden_size: int
num_v_heads: int
num_k_heads: int
head_k_dim: int
head_v_dim: int
key_dim: int
value_dim: int
conv_kernel_size: int
conv_dim: int
conv1d: nn.Conv1d
in_proj_qkvz: nn.Linear
in_proj_ba: nn.Linear
dt_bias: mx.array
A_log: mx.array
out_proj: nn.Linear
def __init__(self, config: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextAttention(nn.Module):
num_attention_heads: int
num_key_value_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextSparseMoeBlock(nn.Module):
norm_topk_prob: bool
num_experts: int
top_k: int
gate: nn.Linear
switch_mlp: SwitchGLU
shared_expert: Qwen3NextMLP
shared_expert_gate: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Qwen3NextDecoderLayer(nn.Module):
is_linear: bool
linear_attn: Qwen3NextGatedDeltaNet
self_attn: Qwen3NextAttention
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
def __init__(self, args: Any, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextModel(nn.Module):
embed_tokens: nn.Embedding
layers: list[Qwen3NextDecoderLayer]
norm: nn.RMSNorm
def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: Qwen3NextModel
lm_head: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[Qwen3NextDecoderLayer]: ...

View File

@@ -113,6 +113,10 @@ class TokenizerWrapper:
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
think_start: str | None
think_end: str | None
think_start_id: int | None
think_end_id: int | None
def __init__(
self,

View File

@@ -431,7 +431,12 @@ def main() -> int:
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Pipeline jaccl is often pointless, skip by default",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
@@ -450,6 +455,7 @@ def main() -> int:
default="bench/results.json",
help="Write raw per-run results JSON to this path.",
)
ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
@@ -533,6 +539,16 @@ def main() -> int:
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
@@ -652,7 +668,9 @@ def main() -> int:
time.sleep(5)
if args.json_out:
if args.stdout:
json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
elif args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump(all_rows, f, indent=2, ensure_ascii=False)
logger.debug(f"\nWrote results JSON: {args.json_out}")

View File

@@ -14,6 +14,7 @@
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
downloadedOnNodes?: string[];
};
let {
@@ -22,6 +23,7 @@
isAdding,
onAdd,
onSelect,
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
@@ -45,6 +47,28 @@
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if downloadedOnNodes.length > 0}
<span
class="flex-shrink-0"
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"

View File

@@ -5,6 +5,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
type ModelFilterPopoverProps = {
@@ -148,6 +149,36 @@
</div>
</div>
<!-- Downloaded only -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
? 'bg-green-500/20 text-green-400 border border-green-500/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() =>
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
>
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="ml-1">Downloaded</span>
</button>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>

View File

@@ -21,6 +21,12 @@
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
@@ -31,6 +37,7 @@
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
};
let {
@@ -43,8 +50,19 @@
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatusMap,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
const groupDownloadStatus = $derived.by(() => {
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
// Return the first available entry (prefer "available" ones)
for (const avail of downloadStatusMap.values()) {
if (avail.available) return avail;
}
return downloadStatusMap.values().next().value;
});
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
@@ -198,10 +216,42 @@
</span>
{/if}
<!-- Variant count -->
<!-- Variant count with size range -->
{#if group.hasMultipleVariants}
{@const sizes = group.variants
.map((v) => v.storage_size_megabytes || 0)
.filter((s) => s > 0)
.sort((a, b) => a - b)}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
sizes[0],
)}-{formatSize(sizes[sizes.length - 1])}){/if}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={groupDownloadStatus.available
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
@@ -305,6 +355,33 @@
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Download indicator for this variant -->
{#if downloadStatusMap?.get(variant.id)}
{@const variantDl = downloadStatusMap.get(variant.id)}
{#if variantDl}
<span
class="flex-shrink-0"
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{/if}
<!-- Check mark if selected -->
{#if isSelected}
<svg

View File

@@ -5,6 +5,7 @@
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
@@ -33,6 +34,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
interface HuggingFaceModel {
@@ -58,6 +60,15 @@
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
@@ -74,6 +85,8 @@
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
@@ -81,9 +94,75 @@
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let filters = $state<FilterState>({
capabilities: [],
sizeRange: null,
downloadedOnly: false,
});
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// Get per-variant download map for a group
function getVariantDownloadMap(
group: ModelGroup,
): Map<string, DownloadAvailability> {
const map = new Map<string, DownloadAvailability>();
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
}
return map;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
@@ -95,15 +174,12 @@
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset state when modal opens
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
if (isOpen) {
searchQuery = "";
selectedFamily = null;
expandedGroups = new Set();
showFilters = false;
hfSearchQuery = "";
hfSearchResults = [];
manualModelId = "";
addModelError = null;
}
@@ -339,6 +415,16 @@
});
}
// Filter to downloaded models only
if (filters.downloadedOnly) {
result = result.filter((g) =>
g.variants.some((v) => {
const avail = modelDownloadAvailability.get(v.id);
return avail && avail.nodeIds.length > 0;
}),
);
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
@@ -385,11 +471,13 @@
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null };
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 || filters.sizeRange !== null,
filters.capabilities.length > 0 ||
filters.sizeRange !== null ||
filters.downloadedOnly,
);
</script>
@@ -576,6 +664,12 @@
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(
downloadsData,
model.id,
).map(getNodeName)
: []}
/>
{/each}
{/if}
@@ -650,6 +744,7 @@
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
/>
{/each}
{/if}
@@ -667,6 +762,11 @@
>{cap}</span
>
{/each}
{#if filters.downloadedOnly}
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
>Downloaded</span
>
{/if}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
@@ -742,6 +842,40 @@
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}

View File

@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the first match found (completed entries are preferred).
*/
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}

View File

@@ -3264,4 +3264,6 @@
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -83,6 +83,9 @@
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
overlays = [
(final: _: { apple-sdk_26 = final.callPackage ./nix/apple-sdk/package.nix { darwinSdkMajorVersion = "26"; }; })
];
};
treefmt = {
projectRootFile = "flake.nix";
@@ -105,7 +108,10 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
shfmt = {
enable = true;
excludes = [ "nix/apple-sdk/**" ];
};
};
};
@@ -118,9 +124,15 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit (self'.packages) metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
sdk-version = pkgs.runCommand "sdk-version" { } ''
mkdir -p $out
echo ${pkgs.apple-sdk_26.version} > $out/version
'';
}
);

View File

@@ -20,7 +20,7 @@ sync-clean:
rust-rebuild:
cargo run --bin stub_gen
just sync-clean
uv sync --reinstall-package exo_pyo3_bindings
build-dashboard:
#!/usr/bin/env bash

nix/apple-sdk/README.md (new, empty file)

View File

@@ -0,0 +1,48 @@
{ lib
, fetchFromGitHub
, stdenvNoCC
,
}:
let
CoreSymbolication = stdenvNoCC.mkDerivation (finalAttrs: {
pname = "CoreSymbolication";
version = "0-unstable-2018-06-17";
src = fetchFromGitHub {
repo = "CoreSymbolication";
owner = "matthewbauer";
rev = "24c87c23664b3ee05dc7a5a87d647ae476a680e4";
hash = "sha256-PzvLq94eNhP0+rLwGMKcMzxuD6MlrNI7iT/eV0obtSE=";
};
patches = [
# Add missing symbol definitions needed to build `zlog` in system_cmds.
# https://github.com/matthewbauer/CoreSymbolication/pull/2
../patches/0001-Add-function-definitions-needed-to-build-zlog-in-sys.patch
../patches/0002-Add-CF_EXPORT-To-const-symbols.patch
];
dontBuild = true;
installPhase = ''
mkdir -p "$out/include"
cp *.h "$out/include"
'';
meta = {
description = "Reverse engineered headers for Apple's CoreSymbolication framework";
homepage = "https://github.com/matthewbauer/CoreSymbolication";
license = lib.licenses.mit;
teams = [ lib.teams.darwin ];
platforms = lib.platforms.darwin;
};
});
in
self: super: {
buildPhase = super.buildPhase or "" + ''
mkdir -p System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
ln -s Versions/Current/Headers System/Library/PrivateFrameworks/CoreSymbolication.framework/Headers
cp '${CoreSymbolication}/include/'*.h System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
'';
}

View File

@@ -0,0 +1,13 @@
{ lib, config }:
self: super: {
preBuild = super.preBuild or "" + ''
platformPath=$out/Platforms/MacOSX.platform
sdkpath=$platformPath/Developer/SDKs
'';
preInstall = super.preInstall or "" + ''
platformPath=$out/Platforms/MacOSX.platform
sdkpath=$platformPath/Developer/SDKs
'';
}

View File

@@ -0,0 +1,38 @@
{ lib
, fetchurl
, cpio
, pbzx
,
}:
{ urls
, version
, hash
,
}:
fetchurl {
pname = "macOS-SDK";
inherit version urls hash;
recursiveHash = true;
nativeBuildInputs = [
cpio
pbzx
];
postFetch = ''
renamed=$(mktemp -d)/sdk.xar
mv "$downloadedFile" "$renamed"
pbzx "$renamed" | cpio -idm
src=Library/Developer/CommandLineTools/SDKs/MacOSX${lib.versions.majorMinor version}.sdk
# Remove unwanted binaries, man pages, and folders from the SDK.
rm -rf $src/usr/bin $src/usr/share $src/System/Library/Perl
mkdir -p "$out"
cp -rd $src/* "$out"
'';
}

View File

@@ -0,0 +1,10 @@
{ makeSetupHook, sdkVersion }:
self: super: {
passthru = super.passthru or { } // {
privateFrameworksHook = makeSetupHook
{
name = "apple-sdk-private-frameworks-hook";
} ../setup-hooks/add-private-frameworks.sh;
};
}

View File

@@ -0,0 +1,38 @@
let
lockfile = builtins.fromJSON (builtins.readFile ../metadata/apple-oss-lockfile.json);
in
{ lib
, fetchFromGitHub
, stdenvNoCC
, sdkVersion
,
}:
let
sdkinfo = lockfile.${sdkVersion};
in
self: super: {
passthru = super.passthru or { } // {
# Returns the raw source from apple-oss-distributions repo.
# This is mostly useful for copying private headers needed to build other source releases.
#
# Note: The source releases are mostly not used to build the SDK. Unless they can be used to build binaries,
# they're not used.
sourceRelease =
name:
let
lockinfo = sdkinfo.${name};
in
fetchFromGitHub
{
owner = "apple-oss-distributions";
repo = name;
rev = lockinfo.rev or "${name}-${lockinfo.version}";
inherit (lockinfo) hash;
}
// {
inherit (lockinfo) version;
};
};
}

View File

@@ -0,0 +1,327 @@
{ lib
, stdenvNoCC
, xcodePlatform
, sdkVersion
,
}:
let
inherit (lib.generators) toPlist;
Info = rec {
CFBundleIdentifier = "com.apple.platform.${Name}";
DefaultProperties = {
COMPRESS_PNG_FILES = "NO";
DEPLOYMENT_TARGET_SETTING_NAME = stdenvNoCC.hostPlatform.darwinMinVersionVariable;
STRIP_PNG_TEXT = "NO";
};
Description = if stdenvNoCC.hostPlatform.isMacOS then "macOS" else "iOS";
FamilyIdentifier = lib.toLower xcodePlatform;
FamilyName = Description;
Identifier = CFBundleIdentifier;
MinimumSDKVersion = stdenvNoCC.hostPlatform.darwinMinVersion;
Name = lib.toLower xcodePlatform;
Type = "Platform";
Version = sdkVersion;
};
# These files are all based off of Xcode spec files found in
# /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/Library/Xcode/PrivatePlugIns/IDEOSXSupportCore.ideplugin/Contents/Resources.
# Based off of the "MacOSX Architectures.xcspec" file. All i386 stuff
# is removed because NixPkgs only supports darwin-x86_64 and darwin-arm64.
Architectures = [
{
Identifier = "Standard";
Type = "Architecture";
Name = "Standard Architectures (Apple Silicon, 64-bit Intel)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD";
}
{
Identifier = "Universal";
Type = "Architecture";
Name = "Universal (Apple Silicon, 64-bit Intel)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_32_64_BIT";
}
{
Identifier = "Native";
Type = "Architecture";
Name = "Native Architecture of Build Machine";
ArchitectureSetting = "NATIVE_ARCH_ACTUAL";
}
{
Identifier = "Standard64bit";
Type = "Architecture";
Name = "Apple Silicon, 64-bit Intel";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_64_BIT";
}
{
Identifier = stdenvNoCC.hostPlatform.darwinArch;
Type = "Architecture";
Name = "Apple Silicon or Intel 64-bit";
}
{
Identifier = "Standard_Including_64_bit";
Type = "Architecture";
Name = "Standard Architectures (including 64-bit)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_INCLUDING_64_BIT";
}
];
# Based off of the "MacOSX Package Types.xcspec" file. Only keep the
# bare minimum needed.
PackageTypes = [
{
Identifier = "com.apple.package-type.mach-o-executable";
Type = "PackageType";
Name = "Mach-O Executable";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.executable";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.mach-o-objfile";
Type = "PackageType";
Name = "Mach-O Object File";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.objfile";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.mach-o-dylib";
Type = "PackageType";
Name = "Mach-O Dynamic Library";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.dylib";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.static-library";
Type = "PackageType";
Name = "Mach-O Static Library";
DefaultBuildSettings = {
EXECUTABLE_PREFIX = "lib";
EXECUTABLE_SUFFIX = ".a";
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "archive.ar";
Name = "$(EXECUTABLE_NAME)";
IsLaunchable = "NO";
};
}
{
Identifier = "com.apple.package-type.wrapper";
Type = "PackageType";
Name = "Wrapper";
DefaultBuildSettings = {
WRAPPER_SUFFIX = ".bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
CONTENTS_FOLDER_PATH = "$(WRAPPER_NAME)/Contents";
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/MacOS";
EXECUTABLE_PATH = "$(EXECUTABLE_FOLDER_PATH)/$(EXECUTABLE_NAME)";
INFOPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/Info.plist";
INFOSTRINGS_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/InfoPlist.strings";
PKGINFO_PATH = "$(CONTENTS_FOLDER_PATH)/PkgInfo";
PBDEVELOPMENTPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/pbdevelopment.plist";
VERSIONPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/version.plist";
PUBLIC_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Headers";
PRIVATE_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PrivateHeaders";
EXECUTABLES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Executables";
FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Frameworks";
SHARED_FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedFrameworks";
SHARED_SUPPORT_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedSupport";
UNLOCALIZED_RESOURCES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Resources";
LOCALIZED_RESOURCES_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/$(DEVELOPMENT_LANGUAGE).lproj";
DOCUMENTATION_FOLDER_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/Documentation";
PLUGINS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PlugIns";
SCRIPTS_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/Scripts";
};
ProductReference = {
FileType = "wrapper.cfbundle";
Name = "$(WRAPPER_NAME)";
IsLaunchable = "NO";
};
}
{
Identifier = "com.apple.package-type.wrapper.application";
Type = "PackageType";
BasedOn = "com.apple.package-type.wrapper";
Name = "Application Wrapper";
DefaultBuildSettings = {
GENERATE_PKGINFO_FILE = "YES";
};
ProductReference = {
FileType = "wrapper.application";
Name = "$(WRAPPER_NAME)";
IsLaunchable = "YES";
};
}
];
# Based off of the "MacOSX Product Types.xcspec" file. All
# bundles/wrapper are removed, because we prefer dynamic products in
# NixPkgs.
ProductTypes = [
{
Identifier = "com.apple.product-type.tool";
Type = "ProductType";
Name = "Command-line Tool";
PackageTypes = [ "com.apple.package-type.mach-o-executable" ];
}
{
Identifier = "com.apple.product-type.objfile";
Type = "ProductType";
Name = "Object File";
PackageTypes = [ "com.apple.package-type.mach-o-objfile" ];
}
{
Identifier = "com.apple.product-type.library.dynamic";
Type = "ProductType";
Name = "Dynamic Library";
PackageTypes = [ "com.apple.package-type.mach-o-dylib" ];
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
MACH_O_TYPE = "mh_dylib";
REZ_EXECUTABLE = "YES";
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
EXECUTABLE_EXTENSION = "dylib";
DYLIB_COMPATIBILITY_VERSION = "1";
DYLIB_CURRENT_VERSION = "1";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "debugging";
GCC_INLINES_ARE_PRIVATE_EXTERN = "YES";
CODE_SIGNING_ALLOWED = "YES";
CODE_SIGNING_REQUIRED = "NO";
};
}
{
Identifier = "com.apple.product-type.library.static";
Type = "ProductType";
Name = "Static Library";
PackageTypes = [ "com.apple.package-type.static-library" ];
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
MACH_O_TYPE = "staticlib";
REZ_EXECUTABLE = "YES";
EXECUTABLE_PREFIX = "lib";
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
EXECUTABLE_EXTENSION = "a";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "debugging";
SEPARATE_STRIP = "YES";
CLANG_ENABLE_MODULE_DEBUGGING = "NO";
};
}
{
Type = "ProductType";
Identifier = "com.apple.product-type.bundle";
Name = "Bundle";
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
MACH_O_TYPE = "mh_bundle";
WRAPPER_PREFIX = "";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "non-global";
};
PackageTypes = [ "com.apple.package-type.wrapper" ];
IsWrapper = "YES";
HasInfoPlist = "YES";
HasInfoPlistStrings = "YES";
}
{
Identifier = "com.apple.product-type.application";
Type = "ProductType";
BasedOn = "com.apple.product-type.bundle";
Name = "Application";
DefaultBuildProperties = {
MACH_O_TYPE = "mh_execute";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "app";
};
PackageTypes = [ "com.apple.package-type.wrapper.application" ];
}
{
Type = "ProductType";
Identifier = "com.apple.product-type.framework";
Name = "Bundle";
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
MACH_O_TYPE = "mh_bundle";
WRAPPER_PREFIX = "";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "non-global";
};
PackageTypes = [ "com.apple.package-type.wrapper" ];
IsWrapper = "YES";
HasInfoPlist = "YES";
HasInfoPlistStrings = "YES";
}
];
ToolchainInfo = {
Identifier = "com.apple.dt.toolchain.XcodeDefault";
};
in
{
"Info.plist" = builtins.toFile "Info.plist" (toPlist { escape = true; } Info);
"ToolchainInfo.plist" = builtins.toFile "ToolchainInfo.plist" (
toPlist { escape = true; } ToolchainInfo
);
"Architectures.xcspec" = builtins.toFile "Architectures.xcspec" (
toPlist { escape = true; } Architectures
);
"PackageTypes.xcspec" = builtins.toFile "PackageTypes.xcspec" (
toPlist { escape = true; } PackageTypes
);
"ProductTypes.xcspec" = builtins.toFile "ProductTypes.xcspec" (
toPlist { escape = true; } ProductTypes
);
}

View File

@@ -0,0 +1,40 @@
let
removedDylibs = [
# corecrypto is available under a very restrictive license (effectively: non-free, can't use).
# Without the headers and not being able to use corecrypto due to its license, it's not very useful.
# Stubs are included in the SDK for all dylibs, including corecrypto. They should be removed.
"/usr/lib/system/libcorecrypto.dylib"
];
in
{ lib
, jq
, llvm
,
}:
self: super: {
nativeBuildInputs = super.nativeBuildInputs or [ ] ++ [
jq
llvm
];
buildPhase = super.buildPhase or "" + ''
echo "Removing the following dylibs from the libSystem reexported libraries list: ${lib.escapeShellArg (lib.concatStringsSep ", " removedDylibs)}"
for libSystem in libSystem.B.tbd libSystem.B_asan.tbd; do
# tbd-v5 is a JSON-based format, which can be manipulated by `jq`.
llvm-readtapi --filetype=tbd-v5 usr/lib/$libSystem \
| jq --argjson libs ${lib.escapeShellArg (builtins.toJSON removedDylibs)} '
if .libraries then
.libraries[] |= select(.install_names[] | any([.] | inside($libs)) | not)
else
.
end
| .main_library.reexported_libraries[].names[] |= select([.] | inside($libs) | not)
' > usr/lib/$libSystem~
# Convert libSystem back to tbd-v4 because not all tooling supports the JSON-based format yet.
llvm-readtapi --filetype=tbd-v4 usr/lib/$libSystem~ -o usr/lib/$libSystem
rm usr/lib/$libSystem~
done
'';
}

View File

@@ -0,0 +1,74 @@
{ lib
, cups
, darwin
, db
, libiconv
, ncurses
, stdenv
, stdenvNoCC
, xcbuild
,
}:
let
# CUPS has too many dependencies to build as part of the Darwin bootstrap. It's also typically taken as an explicit
# dependency by other packages, so building only the headers (to satisfy other SDK headers) should be okay.
cupsHeaders = darwin.bootstrapStdenv.mkDerivation {
pname = "${lib.getName cups}-headers";
version = lib.getVersion cups;
inherit (cups) src;
patches = cups.patches or [ ];
strictDeps = true;
dontBuild = true;
buildInputs = [ darwin.libresolv ]; # The `configure` script requires libresolv headers.
# CUPS's configure script fails to find `ar` when cross-compiling.
configureFlags = [ "ac_cv_path_AR=${stdenv.cc.targetPrefix}ar" ];
installTargets = [ "install-headers" ];
__structuredAttrs = true;
meta = {
inherit (cups.meta)
homepage
description
license
maintainers
platforms
;
};
};
in
self: super: {
# These packages are propagated only because other platforms include them in their libc (or otherwise by default).
# Reducing the number of special cases required to support Darwin makes supporting it easier for package authors.
propagatedBuildInputs =
super.propagatedBuildInputs or [ ]
++ [
libiconv
darwin.libresolv
darwin.libsbuf
# Shipped with the SDK only as a library with no headers
(lib.getLib darwin.libutil)
]
# x86_64-darwin links the object files from Csu when targeting very old releases
++ lib.optionals stdenvNoCC.hostPlatform.isx86_64 [ darwin.Csu ];
# The Darwin module for Swift requires certain headers to be included in the SDK (and not just be propagated).
buildPhase = super.buildPhase or "" + ''
for header in '${lib.getDev libiconv}/include/'* '${lib.getDev ncurses}/include/'* '${cupsHeaders}/include/'*; do
ln -s "$header" "usr/include/$(basename "$header")"
done
'';
# Exported to allow the headers to pass the requisites check in the stdenv bootstrap.
passthru = (super.passthru or { }) // {
cups-headers = cupsHeaders;
};
}

View File

@@ -0,0 +1,53 @@
{ lib
, pkgsBuildHost
, stdenv
, stdenvNoCC
, sdkVersion
,
}:
let
plists = import ./plists.nix {
inherit lib stdenvNoCC sdkVersion;
xcodePlatform = if stdenvNoCC.hostPlatform.isMacOS then "MacOSX" else "iPhoneOS";
};
inherit (pkgsBuildHost) darwin cctools xcbuild;
in
self: super: {
propagatedNativeBuildInputs = super.propagatedNativeBuildInputs or [ ] ++ [ xcbuild.xcrun ];
postInstall = super.postInstall or "" + ''
specspath=$out/Library/Xcode/Specifications
toolchainsPath=$out/Toolchains/XcodeDefault.xctoolchain
mkdir -p "$specspath" "$toolchainsPath"
# xcbuild expects to find things relative to the plist locations. If these are linked instead of copied,
# it won't find any platforms or SDKs.
cp '${plists."Info.plist"}' "$platformPath/Info.plist"
cp '${plists."ToolchainInfo.plist"}' "$toolchainsPath/ToolchainInfo.plist"
for spec in '${xcbuild}/Library/Xcode/Specifications/'*; do
ln -s "$spec" "$specspath/$(basename "$spec")"
done
cp '${plists."Architectures.xcspec"}' "$specspath/Architectures.xcspec"
cp '${plists."PackageTypes.xcspec"}' "$specspath/PackageTypes.xcspec"
cp '${plists."ProductTypes.xcspec"}' "$specspath/ProductTypes.xcspec"
mkdir -p "$out/usr/bin"
ln -s '${xcbuild.xcrun}/bin/xcrun' "$out/usr/bin/xcrun"
# Include `libtool` in the toolchain, so `xcrun -find libtool` can find it without requiring `cctools.libtool`
# as a `nativeBuildInput`.
mkdir -p "$toolchainsPath/usr/bin"
if [ -e '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' ]; then
ln -s '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' "$toolchainsPath/usr/bin/libtool"
fi
# Include additional binutils required by some packages (such as Chromium).
for tool in lipo nm otool size strip; do
if [ -e '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool ]; then
ln -s '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool "$toolchainsPath/usr/bin/$tool"
fi
done
'';
}

View File

@@ -0,0 +1,24 @@
let
disallowedPackages = builtins.fromJSON (builtins.readFile ../metadata/disallowed-packages.json);
in
{ lib
, jq
, stdenv
,
}:
self: super: {
# Remove headers and stubs for packages that are available in nixpkgs.
buildPhase = super.buildPhase or "" + ''
${lib.concatMapStringsSep "\n" (
pkg:
lib.concatLines (
[ ''echo "Removing headers and libraries from ${pkg.package}"'' ]
++ (map (header: "rm -rf -- usr/include/${header}") pkg.headers or [ ])
++ (map (framework: "rm -rf -- System/Library/Frameworks/${framework}") pkg.frameworks or [ ])
++ (map (library: "rm -rf -- usr/lib/${library}") pkg.libraries or [ ])
)
) disallowedPackages}
'';
}

View File

@@ -0,0 +1,9 @@
{}:
self: super: {
buildPhase = ''
runHook preBuild
${super.buildPhase or ""}
runHook postBuild
'';
}

View File

@@ -0,0 +1,536 @@
{
"14.4": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-/VoOR9wJuKnmGE1CWGGXxX8SpmALHnEooNTa3QM+ITc=",
"version": "600028.100.1"
},
"IOAudioFamily": {
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
"version": "540.3"
},
"IOBDStorageFamily": {
"hash": "sha256-UgLMsQBe1QLzlbScmPmASBN7VH4YBmNOUX2CEDezjmE=",
"version": "22"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "61"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "45"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-IUytBKhhCgg0vtI+7q8d5kxpOUgO3tQD7TMy++jrorc=",
"version": "431"
},
"IOFireWireFamily": {
"hash": "sha256-W0KOF4hkA7kFOnL1ThAeFU/YlhFVqoqk9uzGjcBppX8=",
"version": "487"
},
"IOFireWireSBP2": {
"hash": "sha256-bItnRQIaGUxMyiU0q+4N8e5+jYiDEOUPmsrKhBFXvok=",
"version": "445"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
"version": "260"
},
"IOGraphics": {
"hash": "sha256-Ag37fd3tZJLXLVq1yzHOCWGOYYfwwTkC8hnvNaTEaWg=",
"version": "598"
},
"IOHIDFamily": {
"hash": "sha256-fmYTJsquAOBwzsgRmqPyjSJJi1hGcfnMmqLIcTe8W1s=",
"version": "2031.100.16"
},
"IOKitUser": {
"hash": "sha256-1bqRiLvyr2GQfbWwhXHXXIOtIka9YDw5GbKV6bd2k4k=",
"version": "100076.101.1"
},
"IONetworkingFamily": {
"hash": "sha256-J3cLeWKrQ8ypIaqgwRH9eU5JbjEDBVoezj3a2Lvwu5k=",
"version": "177"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-cllpJX11c3CX8zEYdOT2TC63sx7NUAHh33yRHhrG2Ro=",
"version": "315"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-fxBM4KbPwQNVEJl7PCKP+1nUk9Oce/O2+0lVBxyngew=",
"version": "1592.100.35"
},
"Libinfo": {
"hash": "sha256-zZr6Mmou8Q+G6/wS+k0k7R+XirB94TNCUGS5dhi96ZE=",
"version": "583.0.1"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-7X+6S3C7ZOTXJUeDXOOg5EmoZyLZvtE06x3Is0TGgSU=",
"version": "317.100.2"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-HsItciWrwyXujQ2hwqzv0JKOkkuynXYIqejLAEPJbMc=",
"version": "1345.100.2"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-NgTGbaw5JkpboDQpt1fSgUr9NYGS+bIOrEMQX7mLAME=",
"version": "61123.100.169"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-+3xesYxqfsNjWCW3T87OA7+Z1hBqmGEh/I8kP8Ajbso=",
"version": "1300.100.9"
},
"copyfile": {
"hash": "sha256-rSCTgzdHr7QmnPk9rJ9P4fOAolnEQv8PHfgAY+qA0s4=",
"version": "196.100.4"
},
"dtrace": {
"hash": "sha256-04Q35rCKnM5Csv5poFJKpK0VplWq4hvy251/Cb2Kl80=",
"version": "401.100.3"
},
"dyld": {
"hash": "sha256-6P/Da6xP19vmaCROoYv9pl7DaW3/U+qZBJT8PD33bn0=",
"version": "1160.6"
},
"eap8021x": {
"hash": "sha256-Ky6KSlJhyX1NRufGhVBcp+ZFmqYrAxwC/5QvJhC2PhU=",
"version": "354.100.3"
},
"hfs": {
"hash": "sha256-+YUVOttZU7C8I14CC6t3ZH2KxAjjTA2nB0y5bPgLxZM=",
"version": "650.0.2"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-M/jnIHzKYvdFCO0tJ1JXiD/UcZtJhLIoulaCQQUbn30=",
"version": "90"
},
"libdispatch": {
"hash": "sha256-igqIA5DMVHjG30WMHZZpYY7LRM9hZyMWItD+UxeTehY=",
"version": "1477.100.9"
},
"libmalloc": {
"hash": "sha256-Sh4/z7lGWRMldOPURkP5vLOAb5Ou6AUsVJEWz9wk9hI=",
"version": "521.100.59"
},
"libplatform": {
"hash": "sha256-gojt3sWOr7XO2yYI/B1CmNLTPFieSfoNtlOgQahOCok=",
"version": "316.100.10"
},
"libpthread": {
"hash": "sha256-phjfN8+IU8ibPsflR6LktnSi3giy89ghI+cFyrhiQNo=",
"version": "519.101.1"
},
"mDNSResponder": {
"hash": "sha256-0ECbWeMnIRTsi03BeBEe5boyR/84JJPbxzPQze8hHSA=",
"version": "2200.100.94.0.2"
},
"objc4": {
"hash": "sha256-eUVSpbyTEOMEdHoxSv6lZIZwB+cW/YWIaTZTcHgGOjo=",
"version": "912.3"
},
"ppp": {
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
"version": "1016"
},
"removefile": {
"hash": "sha256-L6I0u8S3h3uV1veKA5HvkSebbBCd78ymlf//KWbebZo=",
"version": "70.100.4"
},
"xnu": {
"hash": "sha256-j5Ep1RX5DTJqTGszrF4d/JtzUqZ6nA6XoExqcIQ0RVQ=",
"version": "10063.101.15"
}
},
"15.5": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
"version": "600035"
},
"IOAudioFamily": {
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
"version": "600.2"
},
"IOBDStorageFamily": {
"hash": "sha256-s8hTwX0jq2iPULfBLUwpzqtszWuvJrrLGbmrKa/fY4U=",
"version": "24"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "62"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "46"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
"version": "434"
},
"IOFireWireFamily": {
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
"version": "490"
},
"IOFireWireSBP2": {
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
"version": "452"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
"version": "261"
},
"IOGraphics": {
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
"version": "599"
},
"IOHIDFamily": {
"hash": "sha256-gEYPyjXgQ2ABGufCKPjmzMdNRLxhELkCvOURCokyTO4=",
"version": "2115.100.21"
},
"IOKitUser": {
"hash": "sha256-p32U+jHfwA/tqnjF4p1BmojghEXK8KxiflW3IHs2iIY=",
"version": "100150.120.2"
},
"IONetworkingFamily": {
"hash": "sha256-gZ7Dkk4Iu7AV9K2ioqSeJ1W7bTNxv77bmT18iv3ljLg=",
"version": "185"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-/0H0tqWUWkgYigYypucbc7lOCFYDuukwF9fvLEOhwOk=",
"version": "323"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-nWDokN0Vr5pUyNGculnDOah9RNgHiWr3S13RSQLmZrc=",
"version": "1698.100.8"
},
"Libinfo": {
"hash": "sha256-UI5mGvzZ6BPafGYD6CrNAJAKjeJLB6urAS2lpB6X/Ec=",
"version": "597"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-GDYMVi1034f9empq0YOuumQp/BDJ7phTb0Zl4KTY9xg=",
"version": "342"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-nawWJiu2IJ34ek5iOX6CrlqMzev7TuJpUkvDp30ZQ/U=",
"version": "1351"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-ZOrOOCk+hZbzDilzkihpQfsDpzV3Ul4zy6fpFRWUQHw=",
"version": "61439.120.27"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-ZdUq1SrOwB88Lx68ekrA4zeVsLDZz4TAJywNnF+uAzY=",
"version": "1351.120.3"
},
"copyfile": {
"hash": "sha256-rLqT6e44W2ohgwUXREmiOyJBYCrV3gRLbtVnbUq60xc=",
"version": "221.121.1"
},
"dtrace": {
"hash": "sha256-iNEZyxK3DmEwO3gzrfvCaVZSEuuOMQm5IG/6FodPNdI=",
"version": "411"
},
"dyld": {
"hash": "sha256-4OOghgUYyMJbsTe96fiWCndTJ1BS94rK9v6Kqn/ooYs=",
"version": "1285.19"
},
"eap8021x": {
"hash": "sha256-Kx/wwnt108hDm0qQPyTNbZ8KoHkD5m7L4yb5qjSuQjI=",
"version": "365.120.2"
},
"hfs": {
"hash": "sha256-5/3Ycp3cKqlgAl1kjBmbF5tFlfJYQS5rbrbk4SS66b8=",
"version": "683.120.3"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
"version": "96"
},
"libdispatch": {
"hash": "sha256-jTp2DolOOCQPBt1HRotkmPnKgQ2LGgniEqeHoM+vlKg=",
"version": "1521.120.4"
},
"libmalloc": {
"hash": "sha256-d9AVHSYTqHDlgctv8Hh4HAYW53MJelj4F8LWPsjrsws=",
"version": "715.120.13"
},
"libplatform": {
"hash": "sha256-gpijoTMvdkM0PdG8gyIllOJlh/MtTc4ro9ODDAhN6gM=",
"version": "349"
},
"libpthread": {
"hash": "sha256-N+MMXdbthsxauTTfZ5ElUs39dVH+Chn1yyU6pObZpkU=",
"version": "536"
},
"mDNSResponder": {
"hash": "sha256-ILx12PRxj/+VqfpCCErJFEJXFI9yzTh4g+FK0UCenIE=",
"version": "2600.120.12"
},
"objc4": {
"hash": "sha256-DMxa25gXjKCkiDnVJ/8SyJUjaBlmBGABg8EfCHcmTj0=",
"version": "940.4"
},
"ppp": {
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
"version": "1016"
},
"removefile": {
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
"version": "81"
},
"xnu": {
"hash": "sha256-o4tCuCAIgAYg/Li3wTs12mVWr5C/4vbwu1zi+kJ9d6w=",
"version": "11417.121.6"
}
},
"26.0": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
"version": "600035"
},
"IOAudioFamily": {
"hash": "sha256-A3iiAjjP29VdjMj40tLS5Q/ni4qeh9bBpnmNzeG2pIY=",
"version": "700.2"
},
"IOBDStorageFamily": {
"hash": "sha256-OcQUJ3nEfrpvWX/npnedJ4PECIGWFSLiM0PKoiH911w=",
"version": "26"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "62"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "46"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
"version": "436"
},
"IOFireWireFamily": {
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
"version": "492"
},
"IOFireWireSBP2": {
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
"version": "454"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-cM/VFhVWNVwdJYk+mme0UYttQd7eJwd7Hlo7KNRyHY0=",
"version": "262"
},
"IOGraphics": {
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
"version": "599"
},
"IOHIDFamily": {
"hash": "sha256-YLnabX90g4Q8LxjwVuJF6KODCDxychWV+VJaNG9d8fI=",
"version": "2222.0.24"
},
"IOKitUser": {
"hash": "sha256-ngwi8YMUqE0q8j7Lr5cqJwi2V+IDu3ie3bduotHIUJU=",
"version": "100222.0.4"
},
"IONetworkingFamily": {
"hash": "sha256-ZF5ML41Y1l1liQn32qTkcl4mMvx9Xdizb9VgvTzVTL4=",
"version": "186"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-1FKSF622qeXPGngA3UmQ2M/IU1pdlMoYBPbXytUFDaQ=",
"version": "331"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-k+HQ+qgye0ORFm0hU8WzE4ysbbEoFZ7wcbVl5giDH/E=",
"version": "1725.0.11"
},
"Libinfo": {
"hash": "sha256-4InBEPi0n2EMo/8mIBib1Im4iTKRcRJ4IlAcLCigVGk=",
"version": "600"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-p8cJZlBYOFmI1NDHXGYjgcv8z9Ldc1amZuYlxxJfeVY=",
"version": "344.0.1"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-/NlSwPaoTVx+bl9hYsfz3C5MuLdqGv4vdAh0KDbDKmY=",
"version": "1356"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-oxOvZsDoNYZNiWf+MASHrR4Q2o5oaqvK2We51hH7CO8=",
"version": "61901.0.87.0.1"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-58or+OQP788UgQKO7Y8k8pY/enaSqH971ks7xCPu8fA=",
"version": "1385.0.7"
},
"copyfile": {
"hash": "sha256-I9uDi5BDQKa7mO3XpHxv0d6PiROW2ueZ3vGfrsG0OJo=",
"version": "230.0.1.0.1"
},
"dtrace": {
"hash": "sha256-5HpH6Cg8vWWzOX5ADD//izKDvqGnzV05Giju8lmGeyA=",
"version": "413"
},
"dyld": {
"hash": "sha256-jzoFLwbms0rUwzyjYif/r6Rmr4kyn+as/bhc4paEPeY=",
"version": "1323.3"
},
"eap8021x": {
"hash": "sha256-17bseWT4OWMA8hF+YSDDjxhVyJpbpP2xwv8dGti1YoM=",
"version": "368.0.3"
},
"hfs": {
"hash": "sha256-OkgqZ03gwn2hTuHxZrPDmQOrY4Dwu7MrX+BfG+PTgvE=",
"version": "704.0.3.0.2"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
"version": "96"
},
"libdispatch": {
"hash": "sha256-L0+Ho9dAlMXVpqFEGIcIMsJc0gULckRulUImNEZe5MU=",
"version": "1542.0.4"
},
"libmalloc": {
"hash": "sha256-482hgm1ESr3LWC/JhuQNGNu9smsa2Eap49/eH+YNAio=",
"version": "792.1.1"
},
"libplatform": {
"hash": "sha256-wGZ2Im81mRXx6epgj/tbOJpg89CEbAr0Z8oFEpkyNMU=",
"version": "359.1.2"
},
"libpthread": {
"hash": "sha256-VuMpQjxuMsdHsFq0q6QIWSWi88gVF2jNzIfti20Gkbw=",
"version": "539"
},
"mDNSResponder": {
"hash": "sha256-iRqCpPAQDRjgRbRz3s6q2oyzq6xo+w4FTBai79104Zo=",
"version": "2881.0.25"
},
"objc4": {
"hash": "sha256-Nlgr36yLvGkUJIEFQ5w8FAB0r2syEsRTw0KuUShNT8E=",
"version": "950"
},
"ppp": {
"hash": "sha256-FzHZ05o7JxwgTqz0e3D68b/DiLu2x2ErzGMh0U78fLo=",
"version": "1020.1.1"
},
"removefile": {
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
"version": "84"
},
"xnu": {
"hash": "sha256-Cuf7kPtsn4CPXqyZmxVsJlA5i+Ikryp8ezJyGrvT63c=",
"version": "12377.1.9"
}
}
}


@@ -0,0 +1,533 @@
[
{
"package": "apache",
"headers": [
"apache2"
]
},
{
"package": "apr",
"headers": [
"apr-1"
],
"libraries": [
"libapr-1.*",
"libaprutil-1.*"
]
},
{
"package": "boringssl",
"libraries": [
"libboringssl.*"
]
},
{
"package": "bzip2",
"headers": [
"bzlib.h"
],
"libraries": [
"libbz2.*"
]
},
{
"package": "corecrypto",
"libraries": [
"system/libcorecrypto*"
]
},
{
"package": "Csu",
"libraries": [
"*.o"
]
},
{
"package": "cups",
"headers": [
"cups"
],
"libraries": [
"libcups*"
]
},
{
"package": "curl",
"headers": [
"curl"
],
"libraries": [
"libcurl.*"
]
},
{
"package": "cyrus_sasl",
"headers": [
"sasl"
],
"libraries": [
"libsasl*"
]
},
{
"package": "editline",
"headers": [
"editline.h",
"editline"
],
"libraries": [
"libedit.*",
"libeditline.*"
]
},
{
"package": "html-tidy",
"headers": [
"tidy*"
],
"libraries": [
"libtidy.*"
]
},
{
"package": "hunspell",
"headers": [
"hunspell"
],
"libraries": [
"libhunspell*"
]
},
{
"package": "icu",
"headers": [
"unicode"
],
"libraries": [
"libicucore.*"
]
},
{
"package": "libarchive",
"headers": [
"archive.h",
"archive_entry.h"
],
"libraries": [
"libarchive.*"
]
},
{
"package": "libc++",
"headers": [
"c++",
"cxxabi.h",
"__cxxabi_config.h"
],
"libraries": [
"libc++*"
]
},
{
"package": "ld64",
"libraries": [
"libcodedirectory.*",
"libcodedirectory_static.*"
]
},
{
"package": "expat",
"headers": [
"expat.h",
"expat_config.h",
"expat_external.h"
],
"libraries": [
"libexpat.*"
]
},
{
"package": "libffi",
"headers": [
"ffi*"
],
"libraries": [
"libffi*"
]
},
{
"package": "libgcc",
"libraries": [
"libgcc*"
]
},
{
"package": "libiconv",
"headers": [
"iconv.h",
"libcharset.h",
"localcharset.h"
],
"libraries": [
"libcharset.*",
"libiconv.*",
"i18n"
]
},
{
"package": "libiodbc",
"libraries": [
"libiodbc*"
]
},
{
"package": "libkrb4",
"libraries": [
"libkrb4.*"
]
},
{
"package": "libkrb5",
"headers": [
"com_err.h",
"gssapi",
"gssapi.h",
"gssrpc",
"kadm5",
"kdb.h",
"krad.h",
"krb5",
"krb5.h",
"profile.h",
"verto-module.h",
"verto.h"
],
"libraries": [
"krb5",
"libcom_err.*",
"libgssapi_krb5.*",
"libgssrpc.*",
"libk5crypto.*",
"libkadm5clnt.*",
"libkadm5clnt_mit.*",
"libkadm5srv.*",
"libkadm5srv_mit.*",
"libkdb5.*",
"libkrad.*",
"libkrb5*",
"libkrb5support.*",
"libverto.*"
]
},
{
"package": "libpcap",
"headers": [
"pcap*"
],
"libraries": [
"libpcap.*"
]
},
{
"package": "libresolv",
"headers": [
"arpa/nameser.h",
"arpa/nameser_compat.h",
"dns.h",
"dns_util.h",
"nameser.h",
"resolv.h"
],
"libraries": [
"libresolv.*"
]
},
{
"package": "libstdc++",
"libraries": [
"libstdc++.*"
]
},
{
"package": "libsbuf",
"headers": [
"usbuf.h"
],
"libraries": [
"libsbuf.*"
]
},
{
"package": "libtermcap",
"headers": [
"termcap.h"
],
"libraries": [
"libtermcap.*"
]
},
{
"package": "libutil",
"headers": [
"libutil.h"
],
"libraries": [
"libutil.*",
"libutil1.*"
]
},
{
"package": "libxml2",
"headers": [
"libxml",
"libxml2"
],
"libraries": [
"libxml2.*"
]
},
{
"package": "libxo",
"headers": [
"libxo"
],
"libraries": [
"libxo.*"
]
},
{
"package": "libxslt",
"headers": [
"libexslt",
"libxslt"
],
"libraries": [
"libexslt.*",
"libxslt.*"
]
},
{
"package": "liby",
"libraries": [
"liby.a"
]
},
{
"package": "marisa-trie",
"libraries": [
"libmarisa.*"
]
},
{
"package": "ncurses",
"headers": [
"curses*",
"cursslk.h",
"eti.h",
"etip.h",
"form.h",
"menu.h",
"nc_tparm.h",
"ncurses*",
"panel.h",
"term.h",
"term_entry.h",
"termcap.h",
"tic.h",
"unctrl.h"
],
"libraries": [
"libcurses.*",
"libform.*",
"libformw.*",
"libmenu.*",
"libmenuw.*",
"libncurses.*",
"libncursesw.*",
"libpanel.*",
"libpanelw.*",
"libtinfo.*"
]
},
{
"package": "net-snmp",
"headers": [
"net-snmp"
],
"libraries": [
"libnetsnmp*"
]
},
{
"package": "nghttp",
"libraries": [
"lib*nghttp2.*"
]
},
{
"package": "openblas",
"headers": [
"cblas.h",
"f77blas.h",
"lapack.h",
"lapacke.h",
"lapacke_config.h",
"lapacke_mangling.h",
"lapacke_utils.h",
"openblas_config.h"
],
"libraries": [
"libblas.*",
"libcblas.*",
"libclapack.*",
"libf77lapack.*",
"liblapack.*",
"liblapacke.*",
"libopenblas.*",
"libopenblas.*",
"libopenblasp*"
]
},
{
"package": "openldap",
"libraries": [
"liblber.*",
"liblber_r.*",
"libldap.*",
"libldap_r.*"
]
},
{
"package": "openpam",
"headers": [
"security"
],
"libraries": [
"libpam.*",
"pam_*"
]
},
{
"package": "pcre",
"headers": [
"pcre.h",
"pcreposix.h"
],
"libraries": [
"libpcre.*",
"libpcre2*",
"libpcreposix.*"
]
},
{
"package": "php",
"headers": [
"php"
],
"libraries": [
"php"
]
},
{
"package": "postgresql",
"libraries": [
"libecpg*",
"libpg*",
"libpq*"
]
},
{
"package": "python",
"headers": [
"python*"
],
"frameworks": [
"Python.framework"
],
"libraries": [
"libpython*",
"python*"
]
},
{
"package": "readline",
"headers": [
"readline"
],
"libraries": [
"libhistory.*",
"libreadline.*"
]
},
{
"package": "ruby",
"frameworks": [
"Ruby.framework"
],
"libraries": [
"libruby.*",
"ruby"
]
},
{
"package": "sqlite3",
"headers": [
"sqlite3.h",
"sqlite3ext.h"
],
"libraries": [
"libsqlite3.*"
]
},
{
"package": "swift",
"libraries": [
"swift/shims"
]
},
{
"package": "tcl",
"headers": [
"tcl*",
"tk*"
],
"frameworks": [
"Tcl.framework",
"Tk.framework"
],
"libraries": [
"libtcl*",
"libtk*",
"tclConfig.sh",
"tkConfig.sh"
]
},
{
"package": "xar",
"headers": [
"xar"
],
"libraries": [
"libxar.*"
]
},
{
"package": "xz",
"headers": [
"lzma*"
],
"libraries": [
"liblzma.*"
]
},
{
"package": "zlib",
"headers": [
"zconf.h",
"zlib.h"
],
"libraries": [
"libz.*"
]
}
]
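Each entry pairs an upstream package name with shell-style glob patterns for the headers, libraries, and frameworks that package provides inside the SDK. A minimal sketch of how such patterns match file names, with Python's `fnmatch` standing in for whatever matching the Nix side actually performs:

```python
# Illustrative only: match SDK file names against the glob patterns above.
from fnmatch import fnmatch

entry = {"package": "zlib", "headers": ["zconf.h", "zlib.h"], "libraries": ["libz.*"]}
candidates = ["libz.dylib", "libz.1.dylib", "libzstd.dylib"]
matched = [c for c in candidates if any(fnmatch(c, pat) for pat in entry["libraries"])]
assert matched == ["libz.dylib", "libz.1.dylib"]  # "libzstd.dylib" lacks the dot
```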


@@ -0,0 +1,26 @@
{
"14": {
"urls": [
"https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250211001355/https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg"
],
"version": "14.4",
"hash": "sha256-QozDiwY0Czc0g45vPD7G4v4Ra+3DujCJbSads3fJjjM="
},
"15": {
"urls": [
"https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250530132510/https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg"
],
"version": "15.5",
"hash": "sha256-HBiSJuw1XBUK5R/8Sj65c3rftSEvQl/O9ZZVp/g1Amo="
},
"26": {
"urls": [
"https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250915230423/https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg"
],
"version": "26.2",
"hash": "sha256-hXRlMieVv0smna5uiWRwq87IWOaPWtAjAldbi+wQXcw="
}
}
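Each top-level key is an SDK major version, mapping to mirror URLs (Apple's CDN plus an archive.org fallback), the full SDK version, and an SRI hash. A hedged Python sketch of the lookup that `package.nix` below performs in Nix:

```python
# Sketch: resolve an SDK download from metadata/versions.json.
import json

with open("metadata/versions.json") as f:
    sdk_versions = json.load(f)

major = "26"  # stands in for darwinSdkMajorVersion
info = sdk_versions.get(major)
if info is None:
    raise ValueError(f"Unsupported SDK major version: {major}")
print(info["version"], info["urls"][0])  # -> 26.2 https://swcdn.apple.com/...
```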

nix/apple-sdk/package.nix (new file, 110 lines)

@@ -0,0 +1,110 @@
let
sdkVersions = builtins.fromJSON (builtins.readFile ./metadata/versions.json);
in
{ lib
, stdenv
, stdenvNoCC
, substitute
, # Specifies the major version used for the SDK. Uses `hostPlatform.darwinSdkVersion` by default.
darwinSdkMajorVersion ? lib.versions.major stdenv.hostPlatform.darwinSdkVersion
, # Enabling bootstrap disables propagation. Defaults to `false` (meaning to propagate certain packages and `xcrun`)
# except in stage0 of the Darwin stdenv bootstrap.
enableBootstrap ? stdenv.name == "bootstrap-stage0-stdenv-darwin"
, # Required by various phases
callPackage
,
}:
let
sdkInfo =
sdkVersions.${darwinSdkMajorVersion}
or (throw "Unsupported SDK major version: ${darwinSdkMajorVersion}");
sdkVersion = sdkInfo.version;
fetchSDK = callPackage ./common/fetch-sdk.nix { };
phases = lib.composeManyExtensions (
[
(callPackage ./common/add-core-symbolication.nix { })
(callPackage ./common/derivation-options.nix { })
(callPackage ./common/passthru-private-frameworks.nix { inherit sdkVersion; })
(callPackage ./common/passthru-source-release-files.nix { inherit sdkVersion; })
(callPackage ./common/remove-disallowed-packages.nix { })
(callPackage ./common/process-stubs.nix { })
]
# Avoid infinite recursion by not propagating certain packages, so they can themselves build with the SDK.
++ lib.optionals (!enableBootstrap) [
(callPackage ./common/propagate-inputs.nix { })
(callPackage ./common/propagate-xcrun.nix { inherit sdkVersion; })
]
# This has to happen last.
++ [
(callPackage ./common/run-build-phase-hooks.nix { })
]
);
in
stdenvNoCC.mkDerivation (
lib.extends phases (finalAttrs: {
pname = "apple-sdk";
inherit (sdkInfo) version;
src = fetchSDK sdkInfo;
dontConfigure = true;
strictDeps = true;
setupHooks = [
# `role.bash` is copied from `../build-support/setup-hooks/role.bash` due to the requirements not to reference
# paths outside the package when it is in `by-name`. It needs to be kept in sync, but it fortunately does not
# change often. Once `build-support` is available as a package (or some other mechanism), it should be changed
# to whatever that replacement is.
./setup-hooks/role.bash
(substitute {
src = ./setup-hooks/sdk-hook.sh;
substitutions = [
"--subst-var-by"
"sdkVersion"
(lib.escapeShellArgs (lib.splitVersion sdkVersion))
];
})
];
installPhase =
let
sdkName = "MacOSX${lib.versions.majorMinor sdkVersion}.sdk";
sdkMajor = lib.versions.major sdkVersion;
in
''
runHook preInstall
mkdir -p "$sdkpath"
cp -rd . "$sdkpath/${sdkName}"
ln -s "${sdkName}" "$sdkpath/MacOSX${sdkMajor}.sdk"
ln -s "${sdkName}" "$sdkpath/MacOSX.sdk"
# Swift adds these locations to its search paths. Avoid spurious warnings by making sure they exist.
mkdir -p "$platformPath/Developer/Library/Frameworks"
mkdir -p "$platformPath/Developer/Library/PrivateFrameworks"
mkdir -p "$platformPath/Developer/usr/lib"
runHook postInstall
'';
passthru = {
sdkroot = finalAttrs.finalPackage + "/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk";
};
__structuredAttrs = true;
meta = {
description = "Frameworks and libraries required for building packages on Darwin";
homepage = "https://developer.apple.com";
teams = [ lib.teams.darwin ];
platforms = lib.platforms.darwin;
badPlatforms = [ lib.systems.inspect.patterns.is32bit ];
};
})
)


@@ -0,0 +1,48 @@
From 6531da946949a94643e6d8424236174ae64fe0ca Mon Sep 17 00:00:00 2001
From: Randy Eckenrode <randy@largeandhighquality.com>
Date: Sat, 30 Sep 2023 18:02:39 -0400
Subject: [PATCH 1/2] Add function definitions needed to build zlog in
system_cmds
---
CoreSymbolication.h | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
index a413860..f3cf63f 100644
--- a/CoreSymbolication.h
+++ b/CoreSymbolication.h
@@ -324,7 +324,9 @@ CSSymbolOwnerEditRelocations
CSSymbolOwnerForeachRegion
CSSymbolOwnerForeachRegionWithName
CSSymbolOwnerForeachSection
-CSSymbolOwnerForeachSegment
+*/
+void CSSymbolOwnerForeachSegment(CSSymbolOwnerRef owner, void (^block)(CSSegmentRef));
+/*
CSSymbolOwnerForeachSourceInfo
CSSymbolOwnerForeachSymbol
*/
@@ -333,7 +335,9 @@ void CSSymbolOwnerForeachSymbolWithName(CSSymbolOwnerRef owner, const char *sna
/*
CSSymbolOwnerGetArchitecture
CSSymbolOwnerGetBaseAddress
-CSSymbolOwnerGetCFUUIDBytes
+*/
+const CFUUIDBytes* CSSymbolOwnerGetCFUUIDBytes(CSSymbolOwnerRef owner);
+/*
CSSymbolOwnerGetCompatibilityVersion
CSSymbolOwnerGetCurrentVersion
CSSymbolOwnerGetDataFlags
@@ -390,7 +394,7 @@ CSSymbolOwnerSetLoadTimestamp
CSSymbolOwnerSetPath
CSSymbolOwnerSetRelocationCount
*/
-CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
+void CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
/*
CSSymbolOwnerSetUnloadTimestamp
*/
--
2.44.1


@@ -0,0 +1,45 @@
From ae7ac6a7043dbae8e63d6ce5e63dfaf02b5977fe Mon Sep 17 00:00:00 2001
From: Randy Eckenrode <randy@largeandhighquality.com>
Date: Sat, 30 Sep 2023 18:37:18 -0400
Subject: [PATCH 2/2] Add CF_EXPORT To const symbols
---
CoreSymbolication.h | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
index f3cf63f..4124a54 100644
--- a/CoreSymbolication.h
+++ b/CoreSymbolication.h
@@ -49,6 +49,7 @@
#include <CoreFoundation/CoreFoundation.h>
+#include <CoreFoundation/CFBase.h>
#include <mach/mach.h>
@@ -139,13 +140,13 @@ typedef void (^CSSegmentIterator)(CSSegmentRef segment);
* External symbols
*/
-const char* kCSRegionMachHeaderName;
-const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
-const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
-const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
-const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
-const CSSetCallBacks kCSTypeSetCallBacks;
-const CSSetCallBacks kCSTypeSetWeakCallBacks;
+CF_EXPORT const char* kCSRegionMachHeaderName;
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
+CF_EXPORT const CSSetCallBacks kCSTypeSetCallBacks;
+CF_EXPORT const CSSetCallBacks kCSTypeSetWeakCallBacks;
/*
--
2.44.1


@@ -0,0 +1,41 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils curl file gzip jq xcbuild yq
set -eu -o pipefail
catalog=${1-}
if [ -z "$catalog" ]; then
echo "usage: get-sdks-from-catalog.sh <catalog>"
echo " <catalog> Apple software update catalog (may be gzipped)" >&2
exit 1
fi
scratch=$(mktemp)
trap 'rm -f -- "$scratch"' EXIT
if [[ "$(file "$catalog")" =~ gzip ]]; then
gzcat "$catalog" >"$scratch"
else
cp --reflink=auto "$catalog" "$scratch"
fi
# Grab all SDK packages from the catalog
filter='.Products[].Packages[] | select(.URL | test(".*CLTools_macOSNMOS_SDK.pkg")) | "\(.URL)|\(.MetadataURL)"'
declare -A package_list
for package in $(plutil -convert json -o - "$scratch" | jq -r "$filter"); do
package_list[${package%%|*}]=${package#*|}
done
truncate --size 0 "$scratch"
for pkg in "${!package_list[@]}"; do
ver=$(curl --silent "${package_list[$pkg]}" | xq -r '."pkg-info"."@version"')
echo "{\"url\": \"$pkg\", \"version\": \"$(cut -d. -f1-3 <<<"$ver")\", \"long_version\": \"$ver\"}" >>"$scratch"
done
jq -r --slurp '
group_by(.version | split(".")[0])
| map(max_by(.version))
| sort_by(.version)[]
| "Package URL: \(.url)\n Xcode Ver: \(.version) (\(.long_version))\n"' "$scratch"


@@ -0,0 +1,70 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils curl git gnutar jq moreutils nix
set -eu -o pipefail
if [ ! -v 2 ]; then
echo "usage: lock-sdk-deps.sh <SDK version> <Packages>" >&2
echo " <SDK version> Decimal-separated version number." >&2
echo " Must correspond to a tag in https://github.com/apple-oss-distributions/distribution-macOS" >&2
echo " <Packages> List of packages from the distributions-macOS repository." >&2
echo " Packages not in the repository at the tag for <SDK version> will be ignored."
exit 1
fi
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
lockfile=$pkgdir/metadata/apple-oss-lockfile.json
if [ ! -e "$lockfile" ]; then
touch "$lockfile"
fi
workdir=$(mktemp -d)
trap 'rm -rf -- "$workdir"' EXIT
sdkVersion=$1
shift
tag="macos-${sdkVersion//./}"
declare -a packages=("$@")
echo "Locking versions for macOS $sdkVersion using tag '$tag'..."
pushd "$workdir" >/dev/null
git clone --branch "$tag" https://github.com/apple-oss-distributions/distribution-macOS.git &>/dev/null
cd distribution-macOS
for package in "${packages[@]}"; do
# If the tag exists in `release.json`, use that as an optimization to avoid downloading unnecessarily from GitHub.
packageTag=$(jq -r --arg package "$package" '.projects[] | select(.project == $package) | .tag' release.json)
packageCommit=$(git ls-tree -d HEAD "$package" | awk '{print $3}')
if [ ! -d "$package" ]; then
packageCommit=HEAD
fi
# However, sometimes it doesn't exist. In that case, fall back to cloning the repo and checking manually
# which tag corresponds to the commit from the submodule.
if [ -z "$packageTag" ]; then
git clone --no-checkout "https://github.com/apple-oss-distributions/$package.git" ../source &>/dev/null
pushd ../source >/dev/null
packageTag=$(git tag --points-at "$packageCommit")
popd >/dev/null
rm -rf ../source
fi
packageVersion=${packageTag##"$package"-}
curl -OL "https://github.com/apple-oss-distributions/$package/archive/$packageTag.tar.gz" &>/dev/null
tar axf "$packageTag.tar.gz"
packageHash=$(nix --extra-experimental-features nix-command hash path "$package-$packageTag")
pkgsjson="{\"$sdkVersion\": {\"$package\": {\"version\": \"$packageVersion\", \"hash\": \"$packageHash\"}}}"
echo " - Locking $package to version $packageVersion with hash '$packageHash'"
jq --argjson pkg "$pkgsjson" -S '. * $pkg' "$lockfile" | sponge "$lockfile"
done
popd >/dev/null


@@ -0,0 +1,62 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils jq
set -eu -o pipefail
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
echo '{}' >"$pkgdir/metadata/apple-oss-lockfile.json"
declare -a versions
readarray -t versions < <(jq -r '.[].version' "$pkgdir/metadata/versions.json")
declare -a packages=(
CarbonHeaders
CommonCrypto
IOAudioFamily
IOFireWireFamily
IOFWDVComponents
IOFireWireAVC
IOFireWireSBP2
IOFireWireSerialBusProtocolTransport
IOGraphics
IOHIDFamily
IONetworkingFamily
IOSerialFamily
IOStorageFamily
IOBDStorageFamily
IOCDStorageFamily
IODVDStorageFamily
IOUSBFamily
IOKitUser
Libc
Libinfo
Libm
Libnotify
Librpcsvc
Libsystem
OpenDirectory
Security
architecture
configd
copyfile
dtrace
dyld
eap8021x
hfs
launchd
libclosure
libdispatch
libmalloc
libplatform
libpthread
mDNSResponder
objc4
ppp
removefile
xnu
)
for version in "${versions[@]}"; do
"$pkgdir/scripts/lock-sdk-deps.sh" "$version" "${packages[@]}"
done


@@ -0,0 +1,6 @@
function enablePrivateFrameworks() {
export NIX_CFLAGS_COMPILE+=" -iframework $DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
export NIX_LDFLAGS+=" -F$DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
}
preConfigureHooks+=(enablePrivateFrameworks)


@@ -0,0 +1,71 @@
# Since the same derivation can be depended on in multiple ways, we need to
# accumulate *each* role (i.e. host and target platforms relative to the depending
# derivation) in which the derivation is used.
#
# The role is intended to be used as part of other variable names like
# - $NIX_SOMETHING${role_post}
function getRole() {
case $1 in
-1)
role_post='_FOR_BUILD'
;;
0)
role_post=''
;;
1)
role_post='_FOR_TARGET'
;;
*)
echo "@name@: used as improper sort of dependency" >&2
return 1
;;
esac
}
# `hostOffset` describes how the host platform of the package is slid relative
# to the depending package. `targetOffset` likewise describes the target
# platform of the package. Both are brought into scope of the setup hook defined
# for the dependency whose setup hook is being processed relative to the package
# being built.
function getHostRole() {
getRole "$hostOffset"
}
function getTargetRole() {
getRole "$targetOffset"
}
# `depHostOffset` describes how the host platform of the dependencies are slid
# relative to the depending package. `depTargetOffset` likewise describes the
# target platform of dependencies. Both are brought into scope of the
# environment hook defined for the dependency being applied relative to the
# package being built.
function getHostRoleEnvHook() {
getRole "$depHostOffset"
}
function getTargetRoleEnvHook() {
getRole "$depTargetOffset"
}
# This variant is intended specifically for code-producing tool wrapper scripts.
# `NIX_@wrapperName@_TARGET_*_@suffixSalt@` tracks this (needs to be an exported
# env var so can't use fancier data structures).
function getTargetRoleWrapper() {
case $targetOffset in
-1)
export NIX_@wrapperName@_TARGET_BUILD_@suffixSalt@=1
;;
0)
export NIX_@wrapperName@_TARGET_HOST_@suffixSalt@=1
;;
1)
export NIX_@wrapperName@_TARGET_TARGET_@suffixSalt@=1
;;
*)
echo "@name@: used as improper sort of dependency" >&2
return 1
;;
esac
}


@@ -0,0 +1,17 @@
local role_post
getHostRole
local sdkVersionVar=NIX_APPLE_SDK_VERSION${role_post}
local developerDirVar=DEVELOPER_DIR${role_post}
local sdkVersionArr=(@sdkVersion@)
local sdkVersion
sdkVersion=$(printf "%02d%02d%02d" "${sdkVersionArr[0]-0}" "${sdkVersionArr[1]-0}" "${sdkVersionArr[2]-0}")
if [ "$sdkVersion" -gt "${!sdkVersionVar-000000}" ]; then
export "$developerDirVar"='@out@'
export "$sdkVersionVar"="$sdkVersion"
export "SDKROOT${role_post}"="${!developerDirVar}/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk"
fi
unset -v role_post developerDirVar sdkVersion sdkVersionArr sdkVersionVar
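The hook zero-pads each version component to two digits, so the resulting string compares correctly as a plain integer and the newest SDK among all hooks that run wins. A small sketch of the encoding:

```python
# Sketch of the zero-padded encoding used by the hook above: each of
# major/minor/patch is rendered as two digits, so numeric comparison of
# the results orders versions correctly.
def encode(major: int, minor: int = 0, patch: int = 0) -> str:
    return f"{major:02d}{minor:02d}{patch:02d}"

assert encode(14, 4) == "140400"
assert encode(26, 2) == "260200"
assert int(encode(26, 2)) > int(encode(15, 5))  # 26.2 wins over 15.5
```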


@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
version = let v = "0.30.5"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -86,6 +86,7 @@ let
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeBool "MLX_BUILD_CPU" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")


@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"mlx==0.30.5; sys_platform == 'darwin'",
"mlx[cpu]==0.30.5; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -31,8 +31,6 @@ dependencies = [
]
[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
# dependencies only required for development
@@ -63,7 +61,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
@@ -105,6 +103,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
required-version = ">=0.8.6"
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",


@@ -59,6 +59,22 @@
}
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;
runtimeInputs = [ exoVenv ];
runtimeEnv = {
EXO_DASHBOARD_DIR = self'.packages.dashboard;
EXO_RESOURCES_DIR = inputs.self + /resources;
};
text = ''exec python ${path} "$@"'';
};
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
inherit name;
runtimeInputs = [ pkgs.python313 ];
text = ''exec python ${path} "$@"'';
};
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
@@ -66,28 +82,30 @@
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
# Create wrapper script
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
{
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
} // {
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
${pkgs.ruff}/bin/ruff check ${inputs.self}
touch $out
'';
};


@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500


@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020


@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465


@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772


@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945


@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -53,11 +54,10 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
@@ -108,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads and model_id in self.download_status:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads.pop(model_id).cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id


@@ -158,6 +158,78 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
async def _build_file_list_from_local_directory(
model_id: ModelId,
recursive: bool = False,
) -> list[FileListEntry] | None:
"""Build a file list from locally existing model files.
We can only figure out which files we need from the safetensors index, so
a local directory must contain a *.safetensors.index.json and the
safetensors files listed there.
"""
model_dir = (await ensure_models_dir()) / model_id.normalize()
if not await aios.path.exists(model_dir):
return None
def _scan() -> list[FileListEntry] | None:
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
if not index_files:
return None
entries_by_path: dict[str, FileListEntry] = {}
if recursive:
for dirpath, _, filenames in os.walk(model_dir):
for filename in filenames:
if filename.endswith(".partial"):
continue
full_path = Path(dirpath) / filename
rel_path = str(full_path.relative_to(model_dir))
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=full_path.stat().st_size,
)
else:
for item in model_dir.iterdir():
if item.is_file() and not item.name.endswith(".partial"):
entries_by_path[item.name] = FileListEntry(
type="file",
path=item.name,
size=item.stat().st_size,
)
# Add expected weight files from index that haven't been downloaded yet
for index_file in index_files:
try:
index_data = ModelSafetensorsIndex.model_validate_json(
index_file.read_text()
)
relative_dir = index_file.parent.relative_to(model_dir)
for filename in set(index_data.weight_map.values()):
rel_path = (
str(relative_dir / filename)
if relative_dir != Path(".")
else filename
)
if rel_path not in entries_by_path:
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=None,
)
except Exception:
continue
return list(entries_by_path.values())
file_list = await asyncio.to_thread(_scan)
if not file_list:
return None
return file_list
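For reference, the scanner only consults the `weight_map` of the standard Hugging Face `*.safetensors.index.json` format; a minimal sketch of that shape (the values are made up):

```python
# Minimal index shape the scanner relies on; only weight_map is consulted.
minimal_index = {
    "metadata": {"total_size": 123456789},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
    },
}
# The expected shard files are the distinct values:
expected_files = set(minimal_index["weight_map"].values())
assert expected_files == {
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
}
```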
_fetched_file_lists_this_session: set[str] = set()
@@ -183,6 +255,14 @@ async def fetch_file_list_with_cache(
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
local_file_list = await _build_file_list_from_local_directory(
model_id, recursive
)
if local_file_list is not None:
logger.warning(
f"No internet and no cached file list for {model_id} - using local file list"
)
return local_file_list
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
@@ -203,10 +283,18 @@ async def fetch_file_list_with_cache(
except Exception as e:
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
f"No internet and no cached file list for {model_id} - using local file list"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
local_file_list = await _build_file_list_from_local_directory(
model_id, recursive
)
if local_file_list is not None:
logger.warning(
f"Failed to fetch file list for {model_id} and no cache exists, "
)
return local_file_list
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
@@ -378,10 +466,14 @@ async def download_file_with_retry(
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
on_connection_lost()
raise e
break
logger.error(
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)


@@ -195,6 +195,10 @@ class ResumableShardDownloader(ShardDownloader):
self, shard: ShardMetadata
) -> RepoDownloadProgress:
_, progress = await download_shard(
shard, self.on_progress_wrapper, skip_download=True
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return progress


@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -106,6 +105,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
er_send, er_recv = channel[ElectionResult]()
@@ -136,7 +136,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +147,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -188,6 +189,9 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.master.run)
elif (


@@ -1320,29 +1320,40 @@ class API:
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
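The restructured `run` separates task-group cancellation from server shutdown: cancelling the outer scope sets an `anyio.Event`, and the shielded `serve` call uses that event as its `shutdown_trigger`, letting in-flight requests drain. A minimal sketch of the pattern, with a stand-in for hypercorn's `serve`:

```python
# Sketch (not exo's API): a server task shielded from cancellation that
# exits when an anyio.Event fires, mirroring run_api's shutdown_trigger.
import anyio

async def fake_serve(shutdown: anyio.Event) -> None:
    with anyio.CancelScope(shield=True):
        await shutdown.wait()  # hypercorn's serve() waits on shutdown_trigger
        print("drained connections, stopped")

async def main() -> None:
    shutdown = anyio.Event()
    async with anyio.create_task_group() as tg:
        tg.start_soon(fake_serve, shutdown)
        try:
            await anyio.sleep(0.1)  # pretend to run until cancelled/finished
        finally:
            shutdown.set()  # trip the trigger so the server exits cleanly

anyio.run(main)
```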


@@ -6,6 +6,7 @@ from loguru import logger
from exo.master.placement import (
add_instance_to_placements,
cancel_unnecessary_downloads,
delete_instance,
get_transition_events,
place_instance,
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
@@ -66,12 +68,9 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -81,6 +80,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -96,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -278,6 +280,14 @@ class Master:
transition_events = get_transition_events(
self.state.instances, placement
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(


@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -202,3 +208,29 @@ def get_transition_events(
)
return events
def cancel_unnecessary_downloads(
instances: Mapping[InstanceId, Instance],
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
commands: list[DownloadCommand] = []
currently_downloading = [
(k, v.shard_metadata.model_card.model_id)
for k, vs in download_status.items()
for v in vs
if isinstance(v, (DownloadOngoing))
]
active_models = set(
(
node_id,
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
)
for instance in instances.values()
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
)
for pair in currently_downloading:
if pair not in active_models:
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
return commands
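`cancel_unnecessary_downloads` reduces to set membership over `(node, model)` pairs: any ongoing download whose pair no longer appears in an active placement gets a `CancelDownload`. A toy illustration of the matching, with plain tuples in place of exo's ID types:

```python
# Toy version of the pair-matching above (plain tuples, hypothetical names).
downloading = [("node-a", "llama-3"), ("node-b", "qwen-3")]
active = {("node-a", "llama-3")}
to_cancel = [pair for pair in downloading if pair not in active]
assert to_cancel == [("node-b", "qwen-3")]
```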


@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
CommandId,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
TextGeneration,
)
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
download_command_sender=fcds,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:


@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:


@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")


@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (


@@ -3,10 +3,11 @@
from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]


@@ -194,9 +194,10 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
return await to_thread.run_sync(
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():
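`abandon_on_cancel=True` lets a cancelled `receive_async` return immediately, leaving the worker thread to finish its blocking `receive` in the background rather than holding up cancellation (the caveat in the comment above). A self-contained demonstration of the flag, independent of exo:

```python
# Demo: cancellation no longer waits for the blocking call in the thread.
import time
import anyio
from anyio import to_thread

async def main() -> None:
    start = time.monotonic()
    with anyio.move_on_after(0.1):
        await to_thread.run_sync(time.sleep, 1.0, abandon_on_cancel=True)
    assert time.monotonic() - start < 0.5  # cancelled promptly; thread abandoned

anyio.run(main)
```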


@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.base import (
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import MiniMaxAttention
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
if TYPE_CHECKING:
from mlx_lm.models.cache import Cache
TimeoutCallback = Callable[[], None]
@@ -503,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
layer.self_attn.kv_b_proj
)
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
# layer.self_attn.kv_b_proj
# )
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
@@ -624,6 +644,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
return y
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
self.group = group
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: "Cache | None" = None,
) -> mx.array:
batch_dim, seq_dim, _ = x.shape
self._original_layer = cast(MiniMaxAttention, self.original_layer) # type: ignore
queries: mx.array = self._original_layer.q_proj(x)
keys: mx.array = self._original_layer.k_proj(x)
values: mx.array = self._original_layer.v_proj(x)
if getattr(self, "use_qk_norm", False):
q_dim = queries.shape[-1]
k_dim = keys.shape[-1]
n = self.group.size()
qk = mx.concatenate(
[queries, keys], axis=-1
) # (batch_dim, seq_dim, q_dim + k_dim)
qk = mx.distributed.all_gather(
qk, group=self.group
) # (n*batch_dim, seq_dim, q_dim + k_dim)
qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
queries = qk[..., :q_dim].reshape(
batch_dim, seq_dim, -1
) # (batch_dim, seq_dim, n * q_dim)
keys = qk[..., q_dim:].reshape(
batch_dim, seq_dim, -1
) # (batch_dim, seq_dim, n * k_dim)
queries = self._original_layer.q_norm(queries)
keys = self._original_layer.k_norm(keys)
# Split back and take this rank's portion
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
keys = mx.split(keys, n, axis=-1)[self.group.rank()]
queries = queries.reshape(
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
).transpose(0, 2, 1, 3)
keys = keys.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)
values = values.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)
if cache is not None:
queries = self._original_layer.rope(queries, offset=cache.offset)
keys = self._original_layer.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self._original_layer.rope(queries)
keys = self._original_layer.rope(keys)
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self._original_layer.scale, # type: ignore
mask=mask,
)
output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
return self._original_layer.o_proj(output)
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -632,7 +730,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
rank = self.group.rank()
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -643,18 +740,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Shard qk_norm weights if present (must match sharded head count)
if getattr(layer.self_attn, "use_qk_norm", False):
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
@@ -679,18 +769,95 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
if isinstance(layer, Qwen3DecoderLayer):
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
else:
assert isinstance(layer, Qwen3NextDecoderLayer)
if hasattr(layer, "linear_attn"):
linear_attn = layer.linear_attn
linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
linear_attn.in_proj_qkvz
)
linear_attn.in_proj_ba = self.all_to_sharded_linear(
linear_attn.in_proj_ba
)
linear_attn.out_proj = self.sharded_to_all_linear(
linear_attn.out_proj
)
# Shard conv1d: depthwise conv with non-contiguous channel slicing.
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
# Each rank takes its head-slice from each of the three sections.
rank = self.group.rank()
key_dim = linear_attn.key_dim
value_dim = linear_attn.value_dim
key_dim_shard = key_dim // self.N
value_dim_shard = value_dim // self.N
q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
k_idx = mx.arange(
key_dim + rank * key_dim_shard,
key_dim + (rank + 1) * key_dim_shard,
)
v_idx = mx.arange(
2 * key_dim + rank * value_dim_shard,
2 * key_dim + (rank + 1) * value_dim_shard,
)
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
new_conv_dim = key_dim_shard * 2 + value_dim_shard
linear_attn.conv1d.groups = new_conv_dim
num_v_shard = linear_attn.num_v_heads // self.N
v_start = rank * num_v_shard
v_end = v_start + num_v_shard
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]
linear_attn.num_k_heads //= self.N
linear_attn.num_v_heads //= self.N
linear_attn.key_dim = (
linear_attn.head_k_dim * linear_attn.num_k_heads
)
linear_attn.value_dim = (
linear_attn.head_v_dim * linear_attn.num_v_heads
)
linear_attn.conv_dim = (
linear_attn.key_dim * 2 + linear_attn.value_dim
)
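# Worked example (assumed dims): with key_dim=512, value_dim=1024, N=2, rank=1:
# key_dim_shard=256 and value_dim_shard=512, so conv_indices selects channels
# [256:512) for q, [768:1024) for k and [1536:2048) for v, giving
# new_conv_dim = 2*256 + 512 = 1024 depthwise groups on this rank.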
else:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
@@ -700,6 +867,14 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
self.all_to_sharded_linear_in_place(
layer.mlp.shared_expert.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_expert.down_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group


@@ -1,16 +1,14 @@
import os
from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
@@ -26,51 +24,119 @@ _MEMORY_THRESHOLD = float(
)
class KVPrefixCache:
class CacheSnapshot:
"""Snapshot of states at a known token position."""
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
):
self.states = states
self.token_count = token_count
def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
states: list[ArraysCache | RotatingKVCache | None] = []
for c in cache:
if isinstance(c, (ArraysCache, RotatingKVCache)):
states.append(deepcopy(c))
else:
states.append(None)
token_count = cache_length(cache)
return CacheSnapshot(states=states, token_count=token_count)
def _find_nearest_snapshot(
snapshots: list[CacheSnapshot],
target_token_count: int,
) -> CacheSnapshot | None:
best: CacheSnapshot | None = None
for snap in snapshots:
if snap.token_count <= target_token_count and (
best is None or snap.token_count > best.token_count
):
best = snap
return best
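# Example (assumed counts): given snapshots at token_counts [128, 256, 512],
# a target of 300 returns the snapshot at 256; a target of 100 returns None
# and the caller falls back to a fresh cache.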
def has_non_kv_caches(cache: KVCacheType) -> bool:
"""Check if a cache contains any ArraysCache (SSM) entries."""
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None = None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._snapshots.clear()
self._last_used.clear()
def add_kv_cache(self, prompt: str, cache: KVCacheType):
def add_kv_cache(
self,
prompt_tokens: mx.array,
cache: KVCacheType,
ssm_snapshots: list[CacheSnapshot] | None = None,
):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.prompts.append(prompt_tokens)
self.caches.append(deepcopy(cache))
self._snapshots.append(ssm_snapshots)
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
def update_kv_cache(
self,
index: int,
prompt: str,
prompt_tokens: mx.array,
cache: KVCacheType,
snapshots: list[CacheSnapshot] | None,
restore_pos: int,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts[index] = tokenized_prompt
old_snapshots = self._snapshots[index]
merged: list[CacheSnapshot] = []
if old_snapshots:
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
if snapshots:
merged.extend(snapshots)
self.prompts[index] = prompt_tokens
self.caches[index] = deepcopy(cache)
self._snapshots[index] = merged or None
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
def _get_snapshot(
self, entry_index: int, target_token_count: int
) -> tuple[int, CacheSnapshot | None]:
if not has_non_kv_caches(self.caches[entry_index]):
return target_token_count, None
snapshots = self._snapshots[entry_index]
if not snapshots:
return 0, None
snap = _find_nearest_snapshot(snapshots, target_token_count)
if snap is not None:
return snap.token_count, snap
return 0, None
def get_kv_cache(
self,
model: Model,
prompt: str,
prompt_tokens: mx.array,
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
@@ -79,76 +145,71 @@ class KVPrefixCache:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
For models with SSM layers (represented as ArraysCache in mlx), the cache is trimmed to
the nearest SSM snapshot position at or before the match point for correctness.
The same applies to RotatingKVCache.
"""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
max_length = len(tokenized_prompt)
max_length = len(prompt_tokens)
best_snapshot_index, best_snapshot_length = None, 0
best_index: int | None = None
best_length = 0
is_exact = False
# Find best cache
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(prompt_tokens, cached_prompt)
if length > best_length:
best_index, best_length = i, length
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
is_exact = True
best_index, best_length = i, length
break
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
if best_snapshot_index is not None:
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
target = (max_length - 1) if is_exact else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# No usable snapshot — need fresh cache
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
return make_kv_cache(model), prompt_tokens, None
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
prompt_cache = deepcopy(self.caches[best_index])
cached_length = cache_length(self.caches[best_index])
tokens_to_trim = cached_length - restore_pos
if tokens_to_trim > 0:
trim_cache(prompt_cache, tokens_to_trim, restore_snap)
# Reset cache offset to match trimmed length
for c in prompt_cache:
if hasattr(c, "offset"):
c.offset = restore_pos
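# Trim arithmetic (assumed numbers): with cached_length=500 and restore_pos=300,
# 200 tokens are trimmed from the end and offsets are pinned to 300, so
# prefill resumes from prompt_tokens[300:].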
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
self._access_counter += 1
self._last_used[best_index] = self._access_counter
remaining = prompt_tokens[restore_pos:]
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
return prompt_cache, remaining, best_index
def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
# Evict LRU entries until below threshold or only one entry left
# Evict LRU entries until below threshold
while (
len(self.caches) > 1
len(self.caches) > 0
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._snapshots.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
@@ -169,6 +230,21 @@ class KVPrefixCache:
return max_pressure
def trim_cache(
cache: KVCacheType,
num_tokens: int,
snapshot: CacheSnapshot | None = None,
) -> None:
for i, c in enumerate(cache):
if isinstance(c, (ArraysCache, RotatingKVCache)):
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
@@ -177,14 +253,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(prompt_tokens)
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
# Use the .offset attribute where available (len() is not implemented in older QuantizedKVCache); entries without it count as 0.
return max(getattr(c, "offset", 0) for c in cache)
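# Example (assumed entries): for [KVCache(offset=10), ArraysCache()] this
# returns 10; entries without an offset attribute contribute 0.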
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
@@ -215,7 +291,7 @@ def make_kv_cache(
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
if hasattr(model, "make_cache"):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore


@@ -15,8 +15,3 @@ DEFAULT_TOP_LOGPROBS: int = 5
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)


@@ -1,9 +1,10 @@
import time
from typing import Any, Callable, Generator, cast, get_args
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -23,17 +24,23 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
encode_prompt,
has_non_kv_caches,
make_kv_cache,
snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
DEFAULT_TOP_LOGPROBS,
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
MTP_ENABLED,
MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
fix_unmatched_think_end_tokens,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -49,7 +56,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> tuple[float, int]:
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -60,17 +67,21 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0, 0
return 0.0, 0, []
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
has_ssm = has_non_kv_caches(cache)
snapshots: list[CacheSnapshot] = []
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
@@ -87,7 +98,18 @@ def prefill(
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
# stream_generate added 1 extra generated token to the cache, so we should trim it.
# Because the arrays cache must be rolled back to a snapshot, we generate from 2 tokens, so trim 1 more.
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
for i, c in enumerate(cache):
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
assert pre_gen is not None
if pre_gen.states[i] is not None:
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -95,12 +117,14 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec, num_tokens
# Exclude the last snapshot
return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []
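# Sketch of the returned triple (assumed run): prefilling 4096 tokens in 8s on
# an SSM model yields (512.0, 4096, snapshots), one CacheSnapshot per progress
# callback except the final one, each usable later as a rollback point.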
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -119,7 +143,7 @@ def warmup_inference(
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
sampler = make_sampler(temp=0.0)
logger.info("Generating warmup tokens")
for _r in stream_generate(
@@ -138,9 +162,7 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
mx_barrier(group)
return tokens_generated
@@ -163,11 +185,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def _has_mtp_module(model: Model) -> bool:
"""Check if the model has an attached MTP module."""
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
@@ -228,11 +245,17 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
if task.seed is not None:
mx.random.seed(task.seed)
# TODO: Randomise the task seed and set it in the task params, instead of hard-coding 42.
seed = task.seed or 42
mx.random.seed(seed)
# Encode prompt once at the top and fix unmatched think tags
all_prompt_tokens = encode_prompt(tokenizer, prompt)
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
# Do not use the prefix cache if we are trying to do benchmarks.
is_bench = task.bench
@@ -244,13 +267,16 @@ def mlx_generate(
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
prompt_tokens = all_prompt_tokens
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, all_prompt_tokens
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
if prefix_hit_length > 0:
logger.info(
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -264,24 +290,6 @@ def mlx_generate(
top_k=task.top_k if task.top_k is not None else 0,
)
max_tokens = task.max_output_tokens or MAX_TOKENS
# Check if we should use MTP speculative decoding
use_mtp = MTP_ENABLED and _has_mtp_module(model)
if use_mtp:
logger.info("Using MTP speculative decoding")
yield from _mlx_generate_with_mtp(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
return
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
@@ -291,13 +299,19 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
# stream_generate resumes from the trailing tokens (two here, since prefill trimmed an extra token)
last_token = prompt_tokens[-1:]
last_token = prompt_tokens[-2:]
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
@@ -323,7 +337,6 @@ def mlx_generate(
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)
accumulated_text += out.text
if think_start is not None and out.text == think_start:
@@ -391,16 +404,6 @@ def mlx_generate(
selected_token=out.token,
)
yield GenerationResponse(
text=text,
token=out.token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
usage=usage,
)
if is_done:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
@@ -414,79 +417,44 @@ def mlx_generate(
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
generated_tokens_array = mx.array(
tokenizer.encode(
"".join(generated_text_parts), add_special_tokens=False
)
)
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
kv_prefix_cache.update_kv_cache(
matched_index,
full_prompt_tokens,
caches,
cache_snapshots,
restore_pos=prefix_hit_length,
)
else:
kv_prefix_cache.add_kv_cache(full_prompt, caches)
kv_prefix_cache.add_kv_cache(
full_prompt_tokens, caches, cache_snapshots
)
yield GenerationResponse(
text=text,
token=out.token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
usage=usage,
)
if is_done:
mx_barrier(group)
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
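# Example (assumed stop): with stop_sequences=["END"], max_stop_len is 3, so
# only the last 3 characters are kept between iterations; enough to detect a
# stop string split across streamed chunks.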
# TODO: Do we want an mx_barrier?
def _mlx_generate_with_mtp(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: KVCacheType,
) -> Generator[GenerationResponse]:
"""MTP speculative decoding generation path.
Uses the model's attached MTP module for speculative decoding,
which can provide 1.5-2x speedup with ~81% acceptance rate.
"""
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
mtp_module: Any = model.mtp_module # type: ignore[attr-defined]
for out in mtp_speculative_generate(
model=model,
mtp_module=mtp_module,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=cast(list[Any], prompt_cache),
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
kv_bits=KV_BITS,
):
logger.info(f"{out.text} (from_draft={out.from_draft})")
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=None,
)
if out.finish_reason is not None:
break


@@ -1,6 +0,0 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
__all__ = ["MTPModule", "mtp_speculative_generate"]


@@ -1,207 +0,0 @@
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
The MTP architecture predicts one additional token ahead using:
1. hnorm - RMSNorm for hidden state normalization
2. enorm - RMSNorm for embedding normalization
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
4. transformer_block - Single decoder layer (attention + MLP)
5. Shared embedding/lm_head from main model
Forward pass:
h_norm = hnorm(hidden_state)
e_norm = enorm(embed(token))
projected = eh_proj(concat([h_norm, e_norm]))
new_hidden = transformer_block(projected)
logits = lm_head(output_norm(new_hidden))
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
DeepseekV3Attention,
DeepseekV3MLP,
ModelArgs,
)
MTP_LAYER_INDEX = 61
class MTPModule(nn.Module):
"""Multi-Token Prediction module for DeepSeek V3.
This module is initialized from the layer 61 weights that are normally
discarded during model loading. It enables speculative decoding by
predicting one token ahead using the hidden state from the main model.
"""
def __init__(
self,
config: ModelArgs,
shared_embedding: nn.Embedding,
shared_lm_head: nn.Linear,
output_norm: nn.RMSNorm,
) -> None:
super().__init__()
self.config = config
# MTP-specific normalization layers
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Projection: concatenated [hidden, embedding] -> hidden_size
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
# Single transformer block for MTP
# Use a dense MLP since this is just a single layer
self.transformer_block = MTPTransformerBlock(config)
# Share embedding and lm_head with main model
self._shared_embedding = shared_embedding
self._shared_lm_head = shared_lm_head
self._output_norm = output_norm
def __call__(
self,
hidden_state: mx.array,
draft_token: mx.array,
cache: KVCache | None = None,
mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass for MTP.
Args:
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
draft_token: Token to embed and combine with hidden state [batch, seq_len]
cache: Optional KV cache for the MTP transformer block
mask: Optional attention mask
Returns:
tuple of (logits, new_hidden_state)
"""
# Get embedding of draft token
embedding = self._shared_embedding(draft_token)
# Normalize hidden state and embedding
h_norm = self.hnorm(hidden_state)
e_norm = self.enorm(embedding)
# Project concatenated representation
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
projected = self.eh_proj(concatenated)
# Pass through single transformer block
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
# Apply output norm and get logits
normed_hidden = self._output_norm(new_hidden)
logits = self._shared_lm_head(normed_hidden)
return logits, new_hidden
class MTPTransformerBlock(nn.Module):
"""Single transformer block for MTP.
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
instead of MoE since this is just for the single MTP layer.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.self_attn = DeepseekV3Attention(config)
# MTP uses dense MLP, not MoE
self.mlp = DeepseekV3MLP(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: Any | None = None,
) -> mx.array:
"""Forward pass with residual connections."""
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
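# Pre-norm residual pattern: h = x + attn(input_layernorm(x)),
# out = h + mlp(post_attention_layernorm(h)).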
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
"""Extract MTP-specific weights from layer 61.
The MTP layer has these weight patterns:
- model.layers.61.enorm.weight -> MTP embedding normalization
- model.layers.61.hnorm.weight -> MTP hidden normalization
- model.layers.61.eh_proj.weight -> MTP projection layer
- model.layers.61.self_attn.* -> MTP attention
- model.layers.61.input_layernorm.* -> MTP layer norms
- model.layers.61.post_attention_layernorm.*
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
Args:
weights: Full model weights dict
Returns:
Dict of MTP-specific weights with keys renamed for MTPModule
"""
mtp_weights: dict[str, mx.array] = {}
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
for key, value in weights.items():
if key.startswith(mtp_prefix):
# Remove the layer prefix to get relative path
new_key = key[len(mtp_prefix) :]
mtp_weights[new_key] = value
return mtp_weights
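# Renaming example (per the docstring above):
#   "model.layers.61.enorm.weight"              -> "enorm.weight"
#   "model.layers.61.self_attn.q_a_proj.weight" -> "self_attn.q_a_proj.weight"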
def load_mtp_weights_into_module(
mtp_module: MTPModule,
mtp_weights: dict[str, mx.array],
) -> None:
"""Load extracted MTP weights into the MTPModule.
Args:
mtp_module: The MTPModule instance to load weights into
mtp_weights: Extracted MTP weights from extract_mtp_weights()
"""
# Map weight names to module attributes
weight_mapping: dict[str, str] = {
"enorm.weight": "enorm.weight",
"hnorm.weight": "hnorm.weight",
"eh_proj.weight": "eh_proj.weight",
}
# Load direct mappings
for src_name, dst_name in weight_mapping.items():
if src_name in mtp_weights:
parts = dst_name.split(".")
obj: Any = mtp_module
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], mtp_weights[src_name])
# Load transformer block weights (self_attn, mlp, layer norms)
transformer_prefixes = [
"self_attn",
"mlp",
"input_layernorm",
"post_attention_layernorm",
]
for prefix in transformer_prefixes:
for key, value in mtp_weights.items():
if key.startswith(prefix):
# Navigate to the correct attribute
parts = key.split(".")
obj = mtp_module.transformer_block
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)


@@ -1,506 +0,0 @@
"""MTP Speculative Decoding for DeepSeek V3.
This module implements speculative decoding using the Multi-Token Prediction (MTP)
layer from DeepSeek V3. The key difference from standard speculative decoding is
that MTP requires hidden states from the main model, not just token predictions.
Based on vLLM/SGLang research:
- 81-82% acceptance rate with k=1
- 1.5-2x speedup at low QPS
"""
import functools
import time
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models import cache
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.mtp.module import MTPModule
# Generation stream for async operations
generation_stream = mx.new_stream(mx.default_device())
@dataclass
class MTPGenerationResponse:
"""Response from MTP speculative generation.
Attributes:
text: The next segment of decoded text.
token: The next token.
logprobs: A vector of log probabilities.
from_draft: Whether the token was generated by the MTP draft module.
prompt_tokens: The number of tokens in the prompt.
prompt_tps: The prompt processing tokens-per-second.
generation_tokens: The number of generated tokens.
generation_tps: The tokens-per-second for generation.
peak_memory: The peak memory used so far in GB.
finish_reason: The reason the response is being sent: "length", "stop" or None.
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: str | None = None
def maybe_quantize_kv_cache(
prompt_cache: list[Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
"""Quantize KV cache entries if needed."""
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized")
and hasattr(c, "offset")
and c.offset >= quantized_kv_start
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
class ModelWithHiddenStates(nn.Module):
"""Wrapper to extract hidden states before lm_head.
This wrapper allows capturing the hidden states from the transformer
layers before the final lm_head projection, which is needed for MTP.
"""
def __init__(self, base_model: nn.Module) -> None:
super().__init__()
self._base = base_model
def forward_with_hidden(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass that returns both logits and hidden states.
Args:
inputs: Input token ids
model_cache: KV cache
Returns:
Tuple of (logits, hidden_states)
"""
# Call the inner model (transformer layers + norm)
hidden: mx.array = self._base.model(inputs, model_cache)
# Get logits from lm_head
logits: mx.array = self._base.lm_head(hidden)
return logits, hidden
def forward(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> mx.array:
"""Standard forward pass returning only logits."""
return cast(mx.array, self._base(inputs, cache=model_cache))
@property
def layers(self) -> list[nn.Module]:
"""Access layers for cache creation."""
return cast(list[nn.Module], self._base.layers)
def mtp_speculative_generate_step(
prompt: mx.array,
model: nn.Module,
mtp_module: MTPModule,
*,
num_draft_tokens: int = 1,
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
mtp_cache: KVCache | None = None,
prefill_step_size: int = 512,
kv_bits: int | None = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[tuple[int, mx.array, bool], None, None]:
"""MTP speculative decoding generator.
Unlike standard speculative decoding where the draft model only needs tokens,
MTP requires the hidden states from the main model. This generator:
1. Runs the main model to get logits AND hidden states
2. Uses MTP module with hidden state + sampled token to predict next token
3. Verifies MTP predictions with the main model
4. Accepts/rejects based on matching
Args:
prompt: The input prompt as token ids
model: The main model (must support return_hidden=True)
mtp_module: The MTP module for draft prediction
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
max_tokens: Maximum number of tokens to generate
sampler: Optional sampler function for token selection
logits_processors: Optional list of logits processors
prompt_cache: KV cache for the main model
mtp_cache: KV cache for the MTP module
prefill_step_size: Step size for prompt processing
kv_bits: Bits for KV cache quantization
kv_group_size: Group size for KV cache quantization
quantized_kv_start: Step to begin cache quantization
Yields:
Tuple of (token, logprobs, from_draft)
"""
y = prompt.astype(mx.uint32)
prev_tokens: mx.array | None = None
# Wrap model to get hidden states
wrapped_model = (
model
if isinstance(model, ModelWithHiddenStates)
else ModelWithHiddenStates(model)
)
# Create caches if needed
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model)
if mtp_cache is None:
mtp_cache = KVCache()
final_sampler = (
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _process_and_sample(
tokens: mx.array | None,
logits: mx.array,
) -> tuple[mx.array, mx.array]:
"""Process logits and sample tokens."""
nonlocal logits_processors
processed_logits = logits
if logits_processors:
for processor in logits_processors:
processed_logits = processor(
tokens if tokens is not None else mx.array([]), processed_logits
)
logprobs = processed_logits - mx.logsumexp(
processed_logits, axis=-1, keepdims=True
)
sampled = final_sampler(logprobs)
return sampled, logprobs
def _main_model_step_with_hidden(
input_y: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Run main model step with hidden state return."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits, hidden = wrapped_model.forward_with_hidden(
input_y[None], prompt_cache
)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
def _main_model_step(
input_y: mx.array,
) -> tuple[mx.array, mx.array]:
"""Run main model step without hidden state."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits = wrapped_model.forward(input_y[None], prompt_cache)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0)
def _mtp_draft(
hidden_state: mx.array,
draft_token: mx.array,
) -> tuple[mx.array, mx.array]:
"""Generate draft token using MTP module."""
with mx.stream(generation_stream):
logits, new_hidden = mtp_module(
hidden_state,
draft_token,
cache=mtp_cache,
)
logits = logits[:, -1, :]
sampled, _ = _process_and_sample(None, logits)
return sampled, new_hidden
def _prefill(input_y: mx.array) -> mx.array:
"""Prefill the prompt cache."""
result_y = input_y
while result_y.size > prefill_step_size:
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
result_y = result_y[prefill_step_size:]
mx.clear_cache()
return result_y
def _rewind_cache(num_draft: int, num_accept: int) -> None:
"""Rewind caches after rejection."""
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
# Prefill phase
with mx.stream(generation_stream):
y = _prefill(y)
ntoks = 0
num_draft = 0
n_accepted = 0
last_hidden: mx.array | None = None
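# Accept/reject sketch (k=1, hypothetical tokens): the main model emits token A
# plus its hidden state; the MTP module drafts B; verifying [A, B] either
# confirms B (two tokens emitted for one extra verify step) or rejects it, in
# which case the cache is rewound by one position and the verified token is
# used instead.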
try:
# Initial step to get first token and hidden state
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
while ntoks < max_tokens:
# Draft phase: use MTP to predict next token
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
if num_draft > 0 and last_hidden is not None:
# Use MTP to draft
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
mx.eval(draft_token, draft_hidden)
# Verify with main model
# Feed the drafted token to main model
verify_input = mx.concatenate([y, draft_token.flatten()])
verify_sampled, verify_logprobs, new_hidden = (
_main_model_step_with_hidden(verify_input)
)
mx.eval(verify_sampled, verify_logprobs, new_hidden)
# Check if draft matches verification
draft_token_val = int(draft_token.item())
verify_token_val = (
int(verify_sampled[0].item())
if verify_sampled.shape[0] > 1
else int(verify_sampled.item())
)
# Yield the current token (not from draft)
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
if draft_token_val == verify_token_val:
# Draft accepted
n_accepted += 1
ntoks += 1
draft_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
yield draft_token_val, draft_logprobs, True
if ntoks >= max_tokens:
break
# Continue with the token after the draft
y = (
verify_sampled[-1:]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[-1]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = new_hidden
else:
# Draft rejected - rewind and use verified token
_rewind_cache(1, 0)
y = (
verify_sampled[:1]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = (
new_hidden[:, :1, :] if new_hidden is not None else None
)
else:
# No drafting, just do normal generation
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
if ntoks % 256 == 0:
mx.clear_cache()
finally:
_rewind_cache(num_draft, n_accepted)
def mtp_speculative_generate(
model: nn.Module,
mtp_module: MTPModule,
tokenizer: TokenizerWrapper,
prompt: str | mx.array | list[int],
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
num_draft_tokens: int = 1,
prefill_step_size: int = 512,
kv_group_size: int = 64,
kv_bits: int | None = None,
) -> Generator[MTPGenerationResponse, None, None]:
"""High-level MTP speculative generation with text output.
Args:
model: The main model
mtp_module: The MTP module for draft prediction
tokenizer: Tokenizer for encoding/decoding
prompt: Input prompt (string, array, or token list)
max_tokens: Maximum tokens to generate
sampler: Optional sampler function
logits_processors: Optional logits processors
prompt_cache: Optional KV cache
num_draft_tokens: Number of draft tokens
prefill_step_size: Prefill step size
kv_group_size: KV group size
kv_bits: KV bits
Yields:
MTPGenerationResponse objects with text and metadata
"""
if not isinstance(prompt, mx.array):
if isinstance(prompt, str):
bos_token = getattr(tokenizer, "bos_token", None)
add_special_tokens = bos_token is None or not prompt.startswith(
str(bos_token)
)
encoded: list[int] = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
prompt = mx.array(encoded)
else:
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
token_generator = mtp_speculative_generate_step(
prompt,
model,
mtp_module,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=num_draft_tokens,
prefill_step_size=prefill_step_size,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
tic = time.perf_counter()
prompt_tps = 0.0
token = 0
logprobs: mx.array = mx.array([0.0])
from_draft = False
n = 0
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = float(prompt.size) / prompt_time
tic = time.perf_counter()
if token in eos_token_ids:
break
detokenizer.add_token(token)
if (n + 1) == max_tokens:
break
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason=None,
)
detokenizer.finalize()
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason="stop" if token in eos_token_ids else "length",
)


@@ -1 +0,0 @@
"""Tests for MTP module."""


@@ -1,412 +0,0 @@
"""Unit tests for MTP module components."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.module import (
MTP_LAYER_INDEX,
MTPModule,
MTPTransformerBlock,
extract_mtp_weights,
load_mtp_weights_into_module,
)
class MockModelArgs:
"""Mock ModelArgs for testing without importing deepseek_v3."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
num_attention_heads: int = 4,
num_key_value_heads: int = 4,
rms_norm_eps: float = 1e-6,
vocab_size: int = 1000,
q_lora_rank: int | None = None,
kv_lora_rank: int = 64,
qk_rope_head_dim: int = 16,
v_head_dim: int = 32,
qk_nope_head_dim: int = 32,
rope_theta: float = 10000.0,
rope_scaling: dict | None = None,
attention_bias: bool = False,
max_position_embeddings: int = 2048,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.max_position_embeddings = max_position_embeddings
class TestExtractMTPWeights:
"""Tests for extract_mtp_weights function."""
def test_extracts_layer_61_weights(self) -> None:
"""Should extract only layer 61 weights."""
weights = {
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
"model.layers.61.enorm.weight": mx.ones((10,)),
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
"model.embed_tokens.weight": mx.zeros((100, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 3
assert "enorm.weight" in mtp_weights
assert "hnorm.weight" in mtp_weights
assert "eh_proj.weight" in mtp_weights
# Check values are preserved
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
def test_returns_empty_dict_when_no_layer_61(self) -> None:
"""Should return empty dict when layer 61 doesn't exist."""
weights = {
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 0
def test_handles_nested_layer_61_weights(self) -> None:
"""Should handle nested weight paths like self_attn.q_proj.weight."""
weights = {
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
(10, 10)
),
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert "self_attn.q_a_proj.weight" in mtp_weights
assert "mlp.gate_proj.weight" in mtp_weights
class TestMTPTransformerBlock:
"""Tests for MTPTransformerBlock."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64, intermediate_size=128, num_attention_heads=2
)
def test_forward_shape(self, config: MockModelArgs) -> None:
"""Forward pass should preserve input shape."""
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
output = block(x)
assert output.shape == x.shape
def test_forward_with_mask(self, config: MockModelArgs) -> None:
"""Forward pass should work with attention mask."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
# Create causal mask
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
output = block(x, mask=mask)
assert output.shape == x.shape
class TestMTPModule:
"""Tests for MTPModule."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
@pytest.fixture
def shared_components(
self, config: MockModelArgs
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
return embedding, lm_head, output_norm
def test_initialization(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should initialize with correct components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
assert mtp.hnorm is not None
assert mtp.enorm is not None
assert mtp.eh_proj is not None
assert mtp.transformer_block is not None
def test_forward_output_shapes(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""Forward pass should return correct output shapes."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
batch_size = 2
seq_len = 1
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
logits, new_hidden = mtp(hidden_state, draft_token)
assert logits.shape == (batch_size, seq_len, config.vocab_size)
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
def test_shares_embedding_and_lm_head(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should use shared embedding and lm_head."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Verify they're the same objects
assert mtp._shared_embedding is embedding
assert mtp._shared_lm_head is lm_head
assert mtp._output_norm is output_norm
class TestLoadMTPWeights:
"""Tests for load_mtp_weights_into_module."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
"""Should load enorm and hnorm weights."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Create test weights
test_enorm = mx.ones((config.hidden_size,)) * 3.0
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
mtp_weights = {
"enorm.weight": test_enorm,
"hnorm.weight": test_hnorm,
}
load_mtp_weights_into_module(mtp, mtp_weights)
assert mx.allclose(mtp.enorm.weight, test_enorm)
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
class TestSanitizePatch:
"""Tests for the sanitize patching logic."""
def test_patch_preserves_layer_61(self) -> None:
"""Patching sanitize should preserve layer 61 weights."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
# Get original sanitize behavior
original_sanitize = model_cls.sanitize
try:
# Apply patch
_patch_deepseek_sanitize_for_mtp()
# Note: we can't easily test the full sanitize without a real model
# This test verifies the patch is applied
assert model_cls.sanitize is not original_sanitize
finally:
_restore_deepseek_sanitize()
# Verify restore worked
assert model_cls.sanitize is original_sanitize
def test_restore_sanitize(self) -> None:
"""Restoring sanitize should return to original behavior."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is not original_sanitize
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
def test_double_patch_is_safe(self) -> None:
"""Calling patch twice should be safe (idempotent)."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
try:
_patch_deepseek_sanitize_for_mtp()
patched_sanitize = model_cls.sanitize
# Patch again - should be no-op
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is patched_sanitize
finally:
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
class TestModelIdDetection:
"""Tests for DeepSeek V3 model ID detection."""
def test_detects_deepseek_v3(self) -> None:
"""Should detect DeepSeek V3 model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
def test_detects_deepseek_r1(self) -> None:
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
def test_rejects_non_deepseek(self) -> None:
"""Should reject non-DeepSeek model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
def test_case_insensitive(self) -> None:
"""Detection should be case insensitive."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
class TestFlattenParams:
"""Tests for parameter flattening utility."""
def test_flattens_nested_dict(self) -> None:
"""Should flatten nested parameter dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"model": {
"layers": {
"0": {
"weight": mx.zeros((10,)),
}
},
"embed": mx.ones((5,)),
}
}
flat = _flatten_params(params)
assert "model.layers.0.weight" in flat
assert "model.embed" in flat
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
def test_handles_flat_dict(self) -> None:
"""Should handle already-flat dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"weight": mx.zeros((10,)),
"bias": mx.ones((10,)),
}
flat = _flatten_params(params)
assert flat == params

View File

@@ -1,253 +0,0 @@
"""Unit tests for MTP speculative decoding."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.speculative_decode import (
ModelWithHiddenStates,
maybe_quantize_kv_cache,
)
class MockModel(nn.Module):
"""Mock model for testing speculative decoding."""
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
# Create simple model components
self.model = MockInnerModel(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
hidden = self.model(inputs, cache)
return self.lm_head(hidden)
@property
def layers(self) -> list[nn.Module]:
return self._layers
class MockInnerModel(nn.Module):
"""Mock inner model (like DeepseekV3Model)."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.embed_tokens = nn.Embedding(100, hidden_size)
self.norm = nn.RMSNorm(hidden_size)
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
# Simple embedding + norm
embedded = self.embed_tokens(inputs)
return self.norm(embedded)
class TestModelWithHiddenStates:
"""Tests for ModelWithHiddenStates wrapper."""
@pytest.fixture
def mock_model(self) -> MockModel:
return MockModel(hidden_size=64, vocab_size=100)
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
"""Standard forward should return logits."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits = wrapped.forward(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
"""Forward with hidden should return (logits, hidden)."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits, hidden = wrapped.forward_with_hidden(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
assert hidden.shape == (1, 3, mock_model.hidden_size)
def test_layers_property(self, mock_model: MockModel) -> None:
"""Should expose layers property from base model."""
wrapped = ModelWithHiddenStates(mock_model)
assert wrapped.layers == mock_model.layers
assert len(wrapped.layers) == 3
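These (since-removed) tests pin down the wrapper's contract: `forward` behaves like the base model, while `forward_with_hidden` also surfaces the hidden states that feed the LM head. A minimal sketch that would satisfy them, assuming the base model exposes `.model` and `.lm_head` the way MockModel does (`HiddenStateWrapper` is a hypothetical name, not the exo class):

```python
import mlx.core as mx
import mlx.nn as nn

class HiddenStateWrapper(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.base = model

    def forward(self, inputs: mx.array, cache: list | None = None) -> mx.array:
        return self.base(inputs, cache)

    def forward_with_hidden(
        self, inputs: mx.array, cache: list | None = None
    ) -> tuple[mx.array, mx.array]:
        hidden = self.base.model(inputs, cache)  # inner transformer output
        return self.base.lm_head(hidden), hidden

    @property
    def layers(self) -> list[nn.Module]:
        return self.base.layers
```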
class TestMaybeQuantizeKVCache:
"""Tests for KV cache quantization."""
def test_no_quantization_when_bits_none(self) -> None:
"""Should not quantize when kv_bits is None."""
cache = [MockCache(offset=100)]
maybe_quantize_kv_cache(
cache,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=None,
)
# Cache should be unchanged
assert not hasattr(cache[0], "quantized")
def test_respects_quantized_kv_start(self) -> None:
"""Should only quantize caches past the start threshold."""
cache_below = MockCache(offset=30)
cache_above = MockCache(offset=100)
caches = [cache_below, cache_above]
maybe_quantize_kv_cache(
caches,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=4,
)
# Only cache_above should be quantized
assert not getattr(cache_below, "was_quantized", False)
assert getattr(caches[1], "was_quantized", False)
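The contract here is a no-op when `kv_bits` is None, plus in-place replacement of any cache whose offset has passed `quantized_kv_start`. That suggests a loop along these lines (a sketch under those assumptions, not necessarily the exo implementation; the exact threshold comparison is not pinned down by these tests):

```python
def maybe_quantize(caches, quantized_kv_start, kv_group_size, kv_bits):
    if kv_bits is None:
        return  # quantization disabled
    for i, c in enumerate(caches):
        # Replace entries whose offset has passed the threshold with
        # their quantized counterparts.
        if c.offset > quantized_kv_start:
            caches[i] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
```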
class MockCache:
"""Mock KV cache for testing."""
def __init__(self, offset: int = 0) -> None:
self.offset = offset
self.was_quantized = False
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
quantized = MockCache(self.offset)
quantized.was_quantized = True
return quantized
class TestSpeculativeDecodingLogic:
"""Tests for the core speculative decoding logic."""
def test_draft_acceptance_identical_tokens(self) -> None:
"""When draft matches verification, both should be accepted."""
# This tests the logic, not the full generator
draft_token = 42
verify_token = 42
accepted = draft_token == verify_token
assert accepted
def test_draft_rejection_different_tokens(self) -> None:
"""When draft differs from verification, draft should be rejected."""
draft_token = 42
verify_token = 99
accepted = draft_token == verify_token
assert not accepted
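The acceptance rule these tests encode is plain token equality; over a multi-token draft that amounts to keeping the longest agreeing prefix. A hypothetical helper to make that concrete:

```python
def accepted_prefix_len(draft: list[int], verified: list[int]) -> int:
    # Accept draft tokens until the first disagreement with the verifier.
    n = 0
    for d, v in zip(draft, verified):
        if d != v:
            break
        n += 1
    return n

assert accepted_prefix_len([42, 7, 9], [42, 7, 3]) == 2
```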
class TestMTPGenerationResponse:
"""Tests for MTPGenerationResponse dataclass."""
def test_response_creation(self) -> None:
"""Should create response with all fields."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="Hello",
token=42,
logprobs=mx.array([0.1, 0.2]),
from_draft=True,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=5,
generation_tps=50.0,
peak_memory=1.5,
finish_reason=None,
)
assert response.text == "Hello"
assert response.token == 42
assert response.from_draft is True
assert response.finish_reason is None
def test_response_with_finish_reason(self) -> None:
"""Should handle finish_reason."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="",
token=0,
logprobs=mx.array([0.0]),
from_draft=False,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=100,
generation_tps=50.0,
peak_memory=1.5,
finish_reason="length",
)
assert response.finish_reason == "length"
class TestIntegration:
"""Integration tests for the full MTP pipeline."""
def test_mtp_module_with_mock_model(self) -> None:
"""Test MTP module can be created and run with mock components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
from exo.worker.engines.mlx.mtp.module import MTPModule
# Create mock config
class MockConfig:
hidden_size = 64
intermediate_size = 128
num_attention_heads = 2
num_key_value_heads = 2
rms_norm_eps = 1e-6
q_lora_rank = None
kv_lora_rank = 32
qk_rope_head_dim = 8
v_head_dim = 16
qk_nope_head_dim = 16
rope_theta = 10000.0
rope_scaling = None
attention_bias = False
max_position_embeddings = 2048
config = MockConfig()
embedding = nn.Embedding(100, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
output_norm = nn.RMSNorm(config.hidden_size)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Run forward pass
hidden = mx.random.normal((1, 1, config.hidden_size))
token = mx.array([[5]])
logits, new_hidden = mtp(hidden, token)
assert logits.shape == (1, 1, 100)
assert new_hidden.shape == (1, 1, config.hidden_size)
# Verify outputs are valid (not NaN)
assert not mx.any(mx.isnan(logits))
assert not mx.any(mx.isnan(new_hidden))

View File

@@ -3,7 +3,6 @@ import os
import resource
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -25,7 +24,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
MTP_ENABLED,
TRUST_REMOTE_CODE,
)
@@ -71,142 +69,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
def _is_deepseek_v3_model(model: nn.Module) -> bool:
"""Check if the model is DeepSeek V3."""
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)
def _might_be_deepseek_v3(model_id: str) -> bool:
"""Check if model ID suggests this might be DeepSeek V3."""
model_id_lower = model_id.lower()
return "deepseek" in model_id_lower and (
"v3" in model_id_lower or "r1" in model_id_lower
)
def _patch_deepseek_sanitize_for_mtp() -> None:
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights."""
global _original_deepseek_sanitize
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
if _original_deepseek_sanitize is not None:
return
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
def sanitize_with_mtp(
self: DeepSeekV3Model, weights: dict[str, Any]
) -> dict[str, Any]:
if _original_deepseek_sanitize is None:
raise RuntimeError(
"_original_deepseek_sanitize is None - patch not applied correctly"
)
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
mtp_weights = {
k: v
for k, v in weights.items()
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
}
return {**original_result, **mtp_weights}
DeepSeekV3Model.sanitize = sanitize_with_mtp
def _restore_deepseek_sanitize() -> None:
"""Restore the original DeepSeek V3 sanitize method."""
global _original_deepseek_sanitize
if _original_deepseek_sanitize is None:
return
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
_original_deepseek_sanitize = None
def _flatten_params(
params: dict[str, Any],
prefix: str = "",
) -> dict[str, mx.array]:
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
result: dict[str, mx.array] = {}
for key, value in params.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, mx.array):
result[full_key] = value
elif isinstance(value, dict):
result.update(_flatten_params(value, full_key))
return result
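For reference, the mapping this produces, as pinned down by the TestFlattenParams unit tests earlier in this diff:

```python
nested = {"model": {"embed": mx.ones((5,)), "layers": {"0": {"weight": mx.zeros((10,))}}}}
flat = _flatten_params(nested)
# {"model.embed": ..., "model.layers.0.weight": ...}
```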
def _extract_mtp_module(model: nn.Module) -> Any | None:
"""Extract MTP module from a loaded DeepSeek V3 model."""
from exo.worker.engines.mlx.mtp.module import (
MTPModule,
extract_mtp_weights,
load_mtp_weights_into_module,
)
try:
inner_model = getattr(model, "model", None)
if inner_model is None or not hasattr(inner_model, "layers"):
logger.debug("Model doesn't have expected structure for MTP extraction")
return None
layers: list[nn.Module] = inner_model.layers # type: ignore[assignment]
if len(layers) <= MTP_LAYER_INDEX:
logger.debug(
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
)
return None
config = getattr(model, "args", None)
if config is None:
logger.debug("Could not get model config for MTP module")
return None
embed_tokens = getattr(inner_model, "embed_tokens", None)
lm_head = getattr(model, "lm_head", None)
norm = getattr(inner_model, "norm", None)
if embed_tokens is None or lm_head is None or norm is None:
logger.debug("Could not get required model components for MTP")
return None
mtp_module = MTPModule(
config=config,
shared_embedding=embed_tokens,
shared_lm_head=lm_head,
output_norm=norm,
)
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
model_weights = _flatten_params(raw_params)
mtp_weights = extract_mtp_weights(model_weights)
if not mtp_weights:
logger.debug("No MTP weights found in model parameters")
return None
load_mtp_weights_into_module(mtp_module, mtp_weights)
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
inner_model.layers = new_layers # noqa: B010
logger.info(
f"Extracted MTP module, main model now has {len(new_layers)} layers"
)
return mtp_module
except Exception as e:
logger.warning(f"Failed to extract MTP module: {e}")
return None
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -339,52 +201,28 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
model_id = bound_instance.bound_shard.model_card.model_id
mtp_module = None
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
# Patch sanitize for MTP if this might be DeepSeek V3
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
if should_try_mtp:
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
_patch_deepseek_sanitize_for_mtp()
try:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=not should_try_mtp)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
# Extract MTP module if available
if should_try_mtp and _is_deepseek_v3_model(model):
mtp_module = _extract_mtp_module(model)
if mtp_module is not None:
logger.info("Successfully extracted MTP module from DeepSeek V3")
finally:
if should_try_mtp:
_restore_deepseek_sanitize()
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
# Store MTP module on the model for later access
if mtp_module is not None:
model.mtp_module = mtp_module # noqa: B010
return cast(Model, model), tokenizer
@@ -652,6 +490,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
return think_token is not None and prompt.rstrip().endswith(think_token)
def fix_unmatched_think_end_tokens(
tokens: mx.array, tokenizer: TokenizerWrapper
) -> mx.array:
if not tokenizer.has_thinking:
return tokens
assert tokenizer.think_start_id
assert tokenizer.think_end_id
think_start_id: int = tokenizer.think_start_id
think_end_id: int = tokenizer.think_end_id
token_list: list[int] = cast(list[int], tokens.tolist())
result: list[int] = []
depth = 0
for token in token_list:
if token == think_start_id:
depth += 1
elif token == think_end_id:
if depth == 0:
result.append(think_start_id)
else:
depth -= 1
result.append(token)
return mx.array(result)
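A worked example of the repair, using illustrative ids think_start_id=1 and think_end_id=2:

```python
# [5, 2, 6]    -> [5, 1, 2, 6]  unmatched end: a start token is inserted before it
# [1, 5, 2, 6] -> [1, 5, 2, 6]  balanced pair: left untouched
# [2, 1, 2]    -> [1, 2, 1, 2]  first end is unmatched; the later pair already balances
```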
class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.

View File

@@ -98,21 +98,23 @@ class Worker:
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
try:
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:

View File

@@ -193,7 +193,7 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
kv_prefix_cache = KVPrefixCache(group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
@@ -226,6 +226,7 @@ def main(
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
group=group,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
@@ -274,6 +275,7 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=group,
)
# For other thinking models (GLM, etc.), check if we need to
@@ -627,7 +629,7 @@ def parse_thinking_models(
yield response.model_copy(
update={
"text": tokenizer.think_start,
"token": tokenizer.think_start_id, # type: ignore
"token": tokenizer.think_start_id,
}
)
yield response

View File

@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
@@ -49,7 +47,6 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@@ -93,28 +90,29 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
@@ -122,10 +120,6 @@ class RunnerSupervisor:
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(

View File

@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
cache.clear()
assert len(cache.prompts) == 0
@@ -142,10 +142,12 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert cache_length(cache) == len(tokens)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
# Snapshots should be available for models with non-KV caches
assert len(snapshots) > 0
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -159,10 +161,10 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -170,7 +172,7 @@ class TestKVPrefixCacheWithModel:
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, tokens
)
assert matched_index == 0
@@ -191,10 +193,12 @@ class TestKVPrefixCacheWithModel:
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache
)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(short_prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
long_task = TextGenerationTaskParams(
@@ -212,13 +216,12 @@ class TestKVPrefixCacheWithModel:
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, long_prompt
model, long_tokens
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
# remaining_tokens covers from snapshot restore position to end
assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
@@ -235,15 +238,15 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
@@ -273,15 +276,15 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
@@ -298,7 +301,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -328,7 +331,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -352,20 +355,20 @@ class TestKVPrefixCacheWithModel:
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, prompt_tokens
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
# Exact match: remaining_tokens is just the last two tokens
assert len(remaining_tokens) == 2
assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -444,7 +447,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -481,7 +484,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -495,7 +498,7 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -505,19 +508,10 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
# Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
with patch(
"exo.worker.engines.mlx.cache.get_memory_used_percentage",
return_value=0.95,
):
# Trigger eviction by adding a new entry
task = TextGenerationTaskParams(
@@ -529,14 +523,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since the patched get_memory_used_percentage keeps reporting 0.95 after each
# eviction, all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)
assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)

View File

@@ -34,6 +34,7 @@ TOKENIZER_FILE_PATTERNS = [
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
"tool_declaration_ts.py", # Dependency of tokenization_kimi.py
]

53 tests/auto_bench.sh Executable file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env bash
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
sleep 1
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done
for host; do
echo "Waiting for $host..." 1>&2
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
bench_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$bench_runner" | while IFS= read -r model; do
echo "running bench for $model" 1>&2
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$bench_runner@$bench_runner" "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring" >>"./bench/$commit/${model//\//--}.json"
echo
done
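Usage, per the guard at the top of the script: tests/auto_bench.sh host1 [host2 ...], run from a clean checkout whose HEAD is pushed to origin. Per-model results accumulate under bench/<commit hash>/, with slashes in the model id rewritten to --.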

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# pyright: reportAny=false
import json
import subprocess
import sys
from typing import Any, cast
from urllib.request import urlopen
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
# Resolve the host's Tailscale IP: first column of the status line whose
# second column matches the requested hostname.
ip = next(
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
) or sys.exit(f"{h} not found in tailscale")
with urlopen(f"http://{ip}:52415/state", timeout=5) as r:
data = json.loads(r.read()).get("downloads", {})
def mid(x: dict[str, Any]) -> str | None:
# Walk a download record down through the nested keys to its model id,
# substituting empty dicts for missing keys along the way.
for k in (
"DownloadCompleted",
"shardMetadata",
"PipelineShardMetadata",
"modelCard",
"modelId",
):
x = x.get(k, {})
return cast(str | None, x if x != {} else None)
# Keep only the model ids present on every node in the cluster.
common = set[str].intersection(
*[{m for d in nid if (m := mid(d))} for nid in data.values()]
)
for c in common:
print(c)

View File

@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait

2173 uv.lock generated
View File

File diff suppressed because it is too large