mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-07 04:32:28 -05:00
Compare commits
16 Commits
alexcheema
...
jaccl-buil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
103cbdee58 | ||
|
|
dbcc829625 | ||
|
|
30b384e2e6 | ||
|
|
6675feed71 | ||
|
|
9b5cae3db6 | ||
|
|
cf7201f91e | ||
|
|
b315035ae0 | ||
|
|
c8dbbee27b | ||
|
|
f0107e9670 | ||
|
|
9f502793c1 | ||
|
|
c8371349d5 | ||
|
|
6b907398a4 | ||
|
|
572e647908 | ||
|
|
e59ebd986d | ||
|
|
5c2f29f3f2 | ||
|
|
ffe6396c91 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -32,3 +32,6 @@ dashboard/.svelte-kit/
|
||||
# host config snapshots
|
||||
hosts_*.json
|
||||
.swp
|
||||
|
||||
# bench files
|
||||
bench/**/*.json
|
||||
|
||||
@@ -1139,7 +1139,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`flatten`."""
|
||||
|
||||
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
|
||||
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`reshape` but the shape can be passed either as a
|
||||
:obj:`tuple` or as separate arguments.
|
||||
@@ -1222,7 +1222,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`swapaxes`."""
|
||||
|
||||
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
|
||||
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`transpose` but the axes can be passed either as
|
||||
a tuple or as separate arguments.
|
||||
|
||||
@@ -30,6 +30,9 @@ class Conv1d(Module):
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
|
||||
weight: mx.array
|
||||
groups: int
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
||||
@@ -11,7 +11,10 @@ import mlx.core as mx
|
||||
class Cache(Protocol):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
|
||||
offset: int
|
||||
def update_and_fetch(
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
@@ -87,6 +90,7 @@ def create_attention_mask(
|
||||
class _BaseCache(Cache):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.mla import MultiLinear
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
# kv_b_proj: nn.Linear
|
||||
embed_q: MultiLinear
|
||||
unembed_out: MultiLinear
|
||||
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
|
||||
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Type stubs for mlx_lm.models.qwen3_next"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
class Qwen3NextMLP(nn.Module):
|
||||
gate_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextGatedDeltaNet(nn.Module):
|
||||
hidden_size: int
|
||||
num_v_heads: int
|
||||
num_k_heads: int
|
||||
head_k_dim: int
|
||||
head_v_dim: int
|
||||
key_dim: int
|
||||
value_dim: int
|
||||
conv_kernel_size: int
|
||||
conv_dim: int
|
||||
conv1d: nn.Conv1d
|
||||
in_proj_qkvz: nn.Linear
|
||||
in_proj_ba: nn.Linear
|
||||
dt_bias: mx.array
|
||||
A_log: mx.array
|
||||
out_proj: nn.Linear
|
||||
|
||||
def __init__(self, config: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
norm_topk_prob: bool
|
||||
num_experts: int
|
||||
top_k: int
|
||||
gate: nn.Linear
|
||||
switch_mlp: SwitchGLU
|
||||
shared_expert: Qwen3NextMLP
|
||||
shared_expert_gate: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextDecoderLayer(nn.Module):
|
||||
is_linear: bool
|
||||
linear_attn: Qwen3NextGatedDeltaNet
|
||||
self_attn: Qwen3NextAttention
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
|
||||
|
||||
def __init__(self, args: Any, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextModel(nn.Module):
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Qwen3NextDecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: Qwen3NextModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[Qwen3NextDecoderLayer]: ...
|
||||
@@ -113,6 +113,10 @@ class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
think_start: str | None
|
||||
think_end: str | None
|
||||
think_start_id: int | None
|
||||
think_end_id: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -431,7 +431,12 @@ def main() -> int:
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
@@ -450,6 +455,7 @@ def main() -> int:
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
@@ -533,6 +539,16 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
@@ -652,7 +668,9 @@ def main() -> int:
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
if args.stdout:
|
||||
json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
|
||||
elif args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
isAdding: boolean;
|
||||
onAdd: () => void;
|
||||
onSelect: () => void;
|
||||
downloadedOnNodes?: string[];
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -22,6 +23,7 @@
|
||||
isAdding,
|
||||
onAdd,
|
||||
onSelect,
|
||||
downloadedOnNodes = [],
|
||||
}: HuggingFaceResultItemProps = $props();
|
||||
|
||||
function formatNumber(num: number): string {
|
||||
@@ -45,6 +47,28 @@
|
||||
<span class="text-sm font-mono text-white truncate" title={model.id}
|
||||
>{modelName}</span
|
||||
>
|
||||
{#if downloadedOnNodes.length > 0}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
{#if isAdded}
|
||||
<span
|
||||
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
type ModelFilterPopoverProps = {
|
||||
@@ -148,6 +149,36 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Downloaded only -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
|
||||
? 'bg-green-500/20 text-green-400 border border-green-500/30'
|
||||
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
|
||||
onclick={() =>
|
||||
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 inline-block"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="ml-1">Downloaded</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Size range -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
|
||||
|
||||
@@ -21,6 +21,12 @@
|
||||
hasMultipleVariants: boolean;
|
||||
}
|
||||
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
type ModelPickerGroupProps = {
|
||||
group: ModelGroup;
|
||||
isExpanded: boolean;
|
||||
@@ -31,6 +37,7 @@
|
||||
onSelectModel: (modelId: string) => void;
|
||||
onToggleFavorite: (baseModelId: string) => void;
|
||||
onShowInfo: (group: ModelGroup) => void;
|
||||
downloadStatusMap?: Map<string, DownloadAvailability>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -43,8 +50,19 @@
|
||||
onSelectModel,
|
||||
onToggleFavorite,
|
||||
onShowInfo,
|
||||
downloadStatusMap,
|
||||
}: ModelPickerGroupProps = $props();
|
||||
|
||||
// Group-level download status: show if any variant is downloaded
|
||||
const groupDownloadStatus = $derived.by(() => {
|
||||
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
|
||||
// Return the first available entry (prefer "available" ones)
|
||||
for (const avail of downloadStatusMap.values()) {
|
||||
if (avail.available) return avail;
|
||||
}
|
||||
return downloadStatusMap.values().next().value;
|
||||
});
|
||||
|
||||
// Format storage size
|
||||
function formatSize(mb: number | undefined): string {
|
||||
if (!mb) return "";
|
||||
@@ -198,10 +216,42 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Variant count -->
|
||||
<!-- Variant count with size range -->
|
||||
{#if group.hasMultipleVariants}
|
||||
{@const sizes = group.variants
|
||||
.map((v) => v.storage_size_megabytes || 0)
|
||||
.filter((s) => s > 0)
|
||||
.sort((a, b) => a - b)}
|
||||
<span class="text-xs font-mono text-white/30 flex-shrink-0">
|
||||
{group.variants.length} variants
|
||||
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
|
||||
sizes[0],
|
||||
)}-{formatSize(sizes[sizes.length - 1])}){/if}
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Download availability indicator -->
|
||||
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={groupDownloadStatus.available
|
||||
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
|
||||
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
@@ -305,6 +355,33 @@
|
||||
{formatSize(variant.storage_size_megabytes)}
|
||||
</span>
|
||||
|
||||
<!-- Download indicator for this variant -->
|
||||
{#if downloadStatusMap?.get(variant.id)}
|
||||
{@const variantDl = downloadStatusMap.get(variant.id)}
|
||||
{#if variantDl}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
<!-- Check mark if selected -->
|
||||
{#if isSelected}
|
||||
<svg
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import ModelPickerGroup from "./ModelPickerGroup.svelte";
|
||||
import ModelFilterPopover from "./ModelFilterPopover.svelte";
|
||||
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
||||
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
||||
|
||||
interface ModelInfo {
|
||||
id: string;
|
||||
@@ -33,6 +34,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
interface HuggingFaceModel {
|
||||
@@ -58,6 +60,15 @@
|
||||
onDeleteModel: (modelId: string) => Promise<void>;
|
||||
totalMemoryGB: number;
|
||||
usedMemoryGB: number;
|
||||
downloadsData?: Record<string, unknown[]>;
|
||||
topologyNodes?: Record<
|
||||
string,
|
||||
{
|
||||
friendly_name?: string;
|
||||
system_info?: { model_id?: string };
|
||||
macmon_info?: { memory?: { ram_total?: number } };
|
||||
}
|
||||
>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -74,6 +85,8 @@
|
||||
onDeleteModel,
|
||||
totalMemoryGB,
|
||||
usedMemoryGB,
|
||||
downloadsData,
|
||||
topologyNodes,
|
||||
}: ModelPickerModalProps = $props();
|
||||
|
||||
// Local state
|
||||
@@ -81,9 +94,75 @@
|
||||
let selectedFamily = $state<string | null>(null);
|
||||
let expandedGroups = $state<Set<string>>(new Set());
|
||||
let showFilters = $state(false);
|
||||
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
|
||||
let filters = $state<FilterState>({
|
||||
capabilities: [],
|
||||
sizeRange: null,
|
||||
downloadedOnly: false,
|
||||
});
|
||||
let infoGroup = $state<ModelGroup | null>(null);
|
||||
|
||||
// Download availability per model group
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = topologyNodes?.[nodeId];
|
||||
return (
|
||||
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
|
||||
);
|
||||
}
|
||||
|
||||
const modelDownloadAvailability = $derived.by(() => {
|
||||
const result = new Map<string, DownloadAvailability>();
|
||||
if (!downloadsData || !topologyNodes) return result;
|
||||
|
||||
for (const model of models) {
|
||||
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
|
||||
if (nodeIds.length === 0) continue;
|
||||
|
||||
// Sum total RAM across nodes that have the model
|
||||
let totalRamBytes = 0;
|
||||
for (const nodeId of nodeIds) {
|
||||
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
|
||||
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
|
||||
}
|
||||
|
||||
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
|
||||
result.set(model.id, {
|
||||
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
|
||||
nodeNames: nodeIds.map(getNodeName),
|
||||
nodeIds,
|
||||
});
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
// Aggregate download availability per group (available if ANY variant is available)
|
||||
function getGroupDownloadAvailability(
|
||||
group: ModelGroup,
|
||||
): DownloadAvailability | undefined {
|
||||
for (const variant of group.variants) {
|
||||
const avail = modelDownloadAvailability.get(variant.id);
|
||||
if (avail && avail.nodeIds.length > 0) return avail;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Get per-variant download map for a group
|
||||
function getVariantDownloadMap(
|
||||
group: ModelGroup,
|
||||
): Map<string, DownloadAvailability> {
|
||||
const map = new Map<string, DownloadAvailability>();
|
||||
for (const variant of group.variants) {
|
||||
const avail = modelDownloadAvailability.get(variant.id);
|
||||
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
// HuggingFace Hub state
|
||||
let hfSearchQuery = $state("");
|
||||
let hfSearchResults = $state<HuggingFaceModel[]>([]);
|
||||
@@ -95,15 +174,12 @@
|
||||
let manualModelId = $state("");
|
||||
let addModelError = $state<string | null>(null);
|
||||
|
||||
// Reset state when modal opens
|
||||
// Reset transient state when modal opens, but preserve tab selection
|
||||
$effect(() => {
|
||||
if (isOpen) {
|
||||
searchQuery = "";
|
||||
selectedFamily = null;
|
||||
expandedGroups = new Set();
|
||||
showFilters = false;
|
||||
hfSearchQuery = "";
|
||||
hfSearchResults = [];
|
||||
manualModelId = "";
|
||||
addModelError = null;
|
||||
}
|
||||
@@ -339,6 +415,16 @@
|
||||
});
|
||||
}
|
||||
|
||||
// Filter to downloaded models only
|
||||
if (filters.downloadedOnly) {
|
||||
result = result.filter((g) =>
|
||||
g.variants.some((v) => {
|
||||
const avail = modelDownloadAvailability.get(v.id);
|
||||
return avail && avail.nodeIds.length > 0;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Sort: models that fit first, then by size (largest first)
|
||||
result.sort((a, b) => {
|
||||
const aFits = a.variants.some((v) => canModelFit(v.id));
|
||||
@@ -385,11 +471,13 @@
|
||||
}
|
||||
|
||||
function clearFilters() {
|
||||
filters = { capabilities: [], sizeRange: null };
|
||||
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
|
||||
}
|
||||
|
||||
const hasActiveFilters = $derived(
|
||||
filters.capabilities.length > 0 || filters.sizeRange !== null,
|
||||
filters.capabilities.length > 0 ||
|
||||
filters.sizeRange !== null ||
|
||||
filters.downloadedOnly,
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -576,6 +664,12 @@
|
||||
isAdding={addingModelId === model.id}
|
||||
onAdd={() => handleAddModel(model.id)}
|
||||
onSelect={() => handleSelectHfModel(model.id)}
|
||||
downloadedOnNodes={downloadsData
|
||||
? getNodesWithModelDownloaded(
|
||||
downloadsData,
|
||||
model.id,
|
||||
).map(getNodeName)
|
||||
: []}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
@@ -650,6 +744,7 @@
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
onShowInfo={(g) => (infoGroup = g)}
|
||||
downloadStatusMap={getVariantDownloadMap(group)}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
@@ -667,6 +762,11 @@
|
||||
>{cap}</span
|
||||
>
|
||||
{/each}
|
||||
{#if filters.downloadedOnly}
|
||||
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
|
||||
>Downloaded</span
|
||||
>
|
||||
{/if}
|
||||
{#if filters.sizeRange}
|
||||
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
|
||||
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
|
||||
@@ -742,6 +842,40 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
|
||||
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
|
||||
{#if infoDownload}
|
||||
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
|
||||
<div class="flex items-center gap-2 mb-1">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="text-white/40">Downloaded on:</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1 mt-1">
|
||||
{#each infoDownload.nodeNames as nodeName}
|
||||
<span
|
||||
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
|
||||
>
|
||||
{nodeName}
|
||||
</span>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
152
dashboard/src/lib/utils/downloads.ts
Normal file
152
dashboard/src/lib/utils/downloads.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
/**
|
||||
* Shared utilities for parsing and querying download state.
|
||||
*
|
||||
* The download state from `/state` is shaped as:
|
||||
* Record<NodeId, Array<TaggedDownloadEntry>>
|
||||
*
|
||||
* Each entry is a tagged union object like:
|
||||
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
|
||||
*/
|
||||
|
||||
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
|
||||
function unwrapTagged(
|
||||
obj: Record<string, unknown>,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
const keys = Object.keys(obj);
|
||||
if (keys.length !== 1) return null;
|
||||
const tag = keys[0];
|
||||
const payload = obj[tag];
|
||||
if (!payload || typeof payload !== "object") return null;
|
||||
return [tag, payload as Record<string, unknown>];
|
||||
}
|
||||
|
||||
/** Extract the model ID string from a download entry's nested shard_metadata. */
|
||||
export function extractModelIdFromDownload(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): string | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
|
||||
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
|
||||
if (!unwrapped) return null;
|
||||
const [, shardData] = unwrapped;
|
||||
|
||||
const modelMeta = shardData.model_card ?? shardData.modelCard;
|
||||
if (!modelMeta || typeof modelMeta !== "object") return null;
|
||||
|
||||
const meta = modelMeta as Record<string, unknown>;
|
||||
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
|
||||
}
|
||||
|
||||
/** Extract the shard_metadata object from a download entry payload. */
|
||||
export function extractShardMetadata(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): Record<string, unknown> | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
return shardMetadata as Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
|
||||
export function getDownloadTag(
|
||||
entry: unknown,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
if (!entry || typeof entry !== "object") return null;
|
||||
return unwrapTagged(entry as Record<string, unknown>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
|
||||
*/
|
||||
function* iterNodeDownloads(
|
||||
nodeDownloads: unknown[],
|
||||
): Generator<[string, Record<string, unknown>, string]> {
|
||||
for (const entry of nodeDownloads) {
|
||||
const tagged = getDownloadTag(entry);
|
||||
if (!tagged) continue;
|
||||
const [tag, payload] = tagged;
|
||||
const modelId = extractModelIdFromDownload(payload);
|
||||
if (!modelId) continue;
|
||||
yield [tag, payload, modelId];
|
||||
}
|
||||
}
|
||||
|
||||
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
|
||||
export function isModelDownloadedOnNode(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): boolean {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return false;
|
||||
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
|
||||
export function getNodesWithModelDownloaded(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): string[] {
|
||||
const result: string[] = [];
|
||||
for (const nodeId of Object.keys(downloadsData)) {
|
||||
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
|
||||
result.push(nodeId);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find shard metadata for a model from any download entry across all nodes.
|
||||
* Returns the first match found (completed entries are preferred).
|
||||
*/
|
||||
export function getShardMetadataForModel(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): Record<string, unknown> | null {
|
||||
let fallback: Record<string, unknown> | null = null;
|
||||
|
||||
for (const nodeDownloads of Object.values(downloadsData)) {
|
||||
if (!Array.isArray(nodeDownloads)) continue;
|
||||
|
||||
for (const [tag, payload, entryModelId] of iterNodeDownloads(
|
||||
nodeDownloads,
|
||||
)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
const shard = extractShardMetadata(payload);
|
||||
if (!shard) continue;
|
||||
|
||||
if (tag === "DownloadCompleted") return shard;
|
||||
if (!fallback) fallback = shard;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the download status tag for a specific model on a specific node.
|
||||
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
|
||||
*/
|
||||
export function getModelDownloadStatus(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): string | null {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return null;
|
||||
|
||||
let best: string | null = null;
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
if (tag === "DownloadCompleted") return tag;
|
||||
if (tag === "DownloadOngoing") best = tag;
|
||||
else if (!best) best = tag;
|
||||
}
|
||||
return best;
|
||||
}
|
||||
@@ -3264,4 +3264,6 @@
|
||||
onDeleteModel={deleteCustomModel}
|
||||
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
|
||||
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
|
||||
{downloadsData}
|
||||
topologyNodes={data?.nodes}
|
||||
/>
|
||||
|
||||
16
flake.nix
16
flake.nix
@@ -83,6 +83,9 @@
|
||||
_module.args.pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
|
||||
overlays = [
|
||||
(final: _: { apple-sdk_26 = final.callPackage ./nix/apple-sdk/package.nix { darwinSdkMajorVersion = "26"; }; })
|
||||
];
|
||||
};
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
@@ -105,7 +108,10 @@
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
shfmt = {
|
||||
enable = true;
|
||||
excludes = [ "nix/apple-sdk/**" ];
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -118,9 +124,15 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
sdk-version = pkgs.runCommand "sdk-version" { } ''
|
||||
mkdir -p $out
|
||||
echo ${pkgs.apple-sdk_26.version} > $out/version
|
||||
'';
|
||||
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
2
justfile
2
justfile
@@ -20,7 +20,7 @@ sync-clean:
|
||||
|
||||
rust-rebuild:
|
||||
cargo run --bin stub_gen
|
||||
just sync-clean
|
||||
uv sync --reinstall-package exo_pyo3_bindings
|
||||
|
||||
build-dashboard:
|
||||
#!/usr/bin/env bash
|
||||
|
||||
0
nix/apple-sdk/README.md
Normal file
0
nix/apple-sdk/README.md
Normal file
48
nix/apple-sdk/common/add-core-symbolication.nix
Normal file
48
nix/apple-sdk/common/add-core-symbolication.nix
Normal file
@@ -0,0 +1,48 @@
|
||||
{ lib
|
||||
, fetchFromGitHub
|
||||
, stdenvNoCC
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
CoreSymbolication = stdenvNoCC.mkDerivation (finalAttrs: {
|
||||
pname = "CoreSymbolication";
|
||||
version = "0-unstable-2018-06-17";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
repo = "CoreSymbolication";
|
||||
owner = "matthewbauer";
|
||||
rev = "24c87c23664b3ee05dc7a5a87d647ae476a680e4";
|
||||
hash = "sha256-PzvLq94eNhP0+rLwGMKcMzxuD6MlrNI7iT/eV0obtSE=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# Add missing symbol definitions needed to build `zlog` in system_cmds.
|
||||
# https://github.com/matthewbauer/CoreSymbolication/pull/2
|
||||
../patches/0001-Add-function-definitions-needed-to-build-zlog-in-sys.patch
|
||||
../patches/0002-Add-CF_EXPORT-To-const-symbols.patch
|
||||
];
|
||||
|
||||
dontBuild = true;
|
||||
|
||||
installPhase = ''
|
||||
mkdir -p "$out/include"
|
||||
cp *.h "$out/include"
|
||||
'';
|
||||
|
||||
meta = {
|
||||
description = "Reverse engineered headers for Apple's CoreSymbolication framework";
|
||||
homepage = "https://github.com/matthewbauer/CoreSymbolication";
|
||||
license = lib.licenses.mit;
|
||||
teams = [ lib.teams.darwin ];
|
||||
platforms = lib.platforms.darwin;
|
||||
};
|
||||
});
|
||||
in
|
||||
self: super: {
|
||||
buildPhase = super.buildPhase or "" + ''
|
||||
mkdir -p System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
|
||||
ln -s Versions/Current/Headers System/Library/PrivateFrameworks/CoreSymbolication.framework/Headers
|
||||
cp '${CoreSymbolication}/include/'*.h System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
|
||||
'';
|
||||
}
|
||||
13
nix/apple-sdk/common/derivation-options.nix
Normal file
13
nix/apple-sdk/common/derivation-options.nix
Normal file
@@ -0,0 +1,13 @@
|
||||
{ lib, config }:
|
||||
|
||||
self: super: {
|
||||
preBuild = super.preBuild or "" + ''
|
||||
platformPath=$out/Platforms/MacOSX.platform
|
||||
sdkpath=$platformPath/Developer/SDKs
|
||||
'';
|
||||
|
||||
preInstall = super.preInstall or "" + ''
|
||||
platformPath=$out/Platforms/MacOSX.platform
|
||||
sdkpath=$platformPath/Developer/SDKs
|
||||
'';
|
||||
}
|
||||
38
nix/apple-sdk/common/fetch-sdk.nix
Normal file
38
nix/apple-sdk/common/fetch-sdk.nix
Normal file
@@ -0,0 +1,38 @@
|
||||
{ lib
|
||||
, fetchurl
|
||||
, cpio
|
||||
, pbzx
|
||||
,
|
||||
}:
|
||||
|
||||
{ urls
|
||||
, version
|
||||
, hash
|
||||
,
|
||||
}:
|
||||
|
||||
fetchurl {
|
||||
pname = "macOS-SDK";
|
||||
inherit version urls hash;
|
||||
|
||||
recursiveHash = true;
|
||||
|
||||
nativeBuildInputs = [
|
||||
cpio
|
||||
pbzx
|
||||
];
|
||||
|
||||
postFetch = ''
|
||||
renamed=$(mktemp -d)/sdk.xar
|
||||
mv "$downloadedFile" "$renamed"
|
||||
pbzx "$renamed" | cpio -idm
|
||||
|
||||
src=Library/Developer/CommandLineTools/SDKs/MacOSX${lib.versions.majorMinor version}.sdk
|
||||
|
||||
# Remove unwanted binaries, man pages, and folders from the SDK.
|
||||
rm -rf $src/usr/bin $src/usr/share $src/System/Library/Perl
|
||||
|
||||
mkdir -p "$out"
|
||||
cp -rd $src/* "$out"
|
||||
'';
|
||||
}
|
||||
10
nix/apple-sdk/common/passthru-private-frameworks.nix
Normal file
10
nix/apple-sdk/common/passthru-private-frameworks.nix
Normal file
@@ -0,0 +1,10 @@
|
||||
{ makeSetupHook, sdkVersion }:
|
||||
|
||||
self: super: {
|
||||
passthru = super.passthru or { } // {
|
||||
privateFrameworksHook = makeSetupHook
|
||||
{
|
||||
name = "apple-sdk-private-frameworks-hook";
|
||||
} ../setup-hooks/add-private-frameworks.sh;
|
||||
};
|
||||
}
|
||||
38
nix/apple-sdk/common/passthru-source-release-files.nix
Normal file
38
nix/apple-sdk/common/passthru-source-release-files.nix
Normal file
@@ -0,0 +1,38 @@
|
||||
let
|
||||
lockfile = builtins.fromJSON (builtins.readFile ../metadata/apple-oss-lockfile.json);
|
||||
in
|
||||
|
||||
{ lib
|
||||
, fetchFromGitHub
|
||||
, stdenvNoCC
|
||||
, sdkVersion
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
sdkinfo = lockfile.${sdkVersion};
|
||||
in
|
||||
self: super: {
|
||||
passthru = super.passthru or { } // {
|
||||
# Returns the raw source from apple-oss-distributions repo.
|
||||
# This is mostly useful for copying private headers needed to build other source releases.
|
||||
#
|
||||
# Note: The source releases are mostly not used to build the SDK. Unless they can be used to build binaries,
|
||||
# they’re not used.
|
||||
sourceRelease =
|
||||
name:
|
||||
let
|
||||
lockinfo = sdkinfo.${name};
|
||||
in
|
||||
fetchFromGitHub
|
||||
{
|
||||
owner = "apple-oss-distributions";
|
||||
repo = name;
|
||||
rev = lockinfo.rev or "${name}-${lockinfo.version}";
|
||||
inherit (lockinfo) hash;
|
||||
}
|
||||
// {
|
||||
inherit (lockinfo) version;
|
||||
};
|
||||
};
|
||||
}
|
||||
327
nix/apple-sdk/common/plists.nix
Normal file
327
nix/apple-sdk/common/plists.nix
Normal file
@@ -0,0 +1,327 @@
|
||||
{ lib
|
||||
, stdenvNoCC
|
||||
, xcodePlatform
|
||||
, sdkVersion
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
inherit (lib.generators) toPlist;
|
||||
|
||||
Info = rec {
|
||||
CFBundleIdentifier = "com.apple.platform.${Name}";
|
||||
DefaultProperties = {
|
||||
COMPRESS_PNG_FILES = "NO";
|
||||
DEPLOYMENT_TARGET_SETTING_NAME = stdenvNoCC.hostPlatform.darwinMinVersionVariable;
|
||||
STRIP_PNG_TEXT = "NO";
|
||||
};
|
||||
Description = if stdenvNoCC.hostPlatform.isMacOS then "macOS" else "iOS";
|
||||
FamilyIdentifier = lib.toLower xcodePlatform;
|
||||
FamilyName = Description;
|
||||
Identifier = CFBundleIdentifier;
|
||||
MinimumSDKVersion = stdenvNoCC.hostPlatform.darwinMinVersion;
|
||||
Name = lib.toLower xcodePlatform;
|
||||
Type = "Platform";
|
||||
Version = sdkVersion;
|
||||
};
|
||||
|
||||
# These files are all based off of Xcode spec files found in
|
||||
# /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/Library/Xcode/PrivatePlugIns/IDEOSXSupportCore.ideplugin/Contents/Resources.
|
||||
|
||||
# Based off of the "MacOSX Architectures.xcspec" file. All i386 stuff
|
||||
# is removed because NixPkgs only supports darwin-x86_64 and darwin-arm64.
|
||||
Architectures = [
|
||||
{
|
||||
Identifier = "Standard";
|
||||
Type = "Architecture";
|
||||
Name = "Standard Architectures (Apple Silicon, 64-bit Intel)";
|
||||
RealArchitectures = [
|
||||
"arm64"
|
||||
"x86_64"
|
||||
];
|
||||
ArchitectureSetting = "ARCHS_STANDARD";
|
||||
}
|
||||
{
|
||||
Identifier = "Universal";
|
||||
Type = "Architecture";
|
||||
Name = "Universal (Apple Silicon, 64-bit Intel)";
|
||||
RealArchitectures = [
|
||||
"arm64"
|
||||
"x86_64"
|
||||
];
|
||||
ArchitectureSetting = "ARCHS_STANDARD_32_64_BIT";
|
||||
}
|
||||
{
|
||||
Identifier = "Native";
|
||||
Type = "Architecture";
|
||||
Name = "Native Architecture of Build Machine";
|
||||
ArchitectureSetting = "NATIVE_ARCH_ACTUAL";
|
||||
}
|
||||
{
|
||||
Identifier = "Standard64bit";
|
||||
Type = "Architecture";
|
||||
Name = "Apple Silicon, 64-bit Intel";
|
||||
RealArchitectures = [
|
||||
"arm64"
|
||||
"x86_64"
|
||||
];
|
||||
ArchitectureSetting = "ARCHS_STANDARD_64_BIT";
|
||||
}
|
||||
{
|
||||
Identifier = stdenvNoCC.hostPlatform.darwinArch;
|
||||
Type = "Architecture";
|
||||
Name = "Apple Silicon or Intel 64-bit";
|
||||
}
|
||||
{
|
||||
Identifier = "Standard_Including_64_bit";
|
||||
Type = "Architecture";
|
||||
Name = "Standard Architectures (including 64-bit)";
|
||||
RealArchitectures = [
|
||||
"arm64"
|
||||
"x86_64"
|
||||
];
|
||||
ArchitectureSetting = "ARCHS_STANDARD_INCLUDING_64_BIT";
|
||||
}
|
||||
];
|
||||
|
||||
# Based off of the "MacOSX Package Types.xcspec" file. Only keep the
|
||||
# bare minimum needed.
|
||||
PackageTypes = [
|
||||
{
|
||||
Identifier = "com.apple.package-type.mach-o-executable";
|
||||
Type = "PackageType";
|
||||
Name = "Mach-O Executable";
|
||||
DefaultBuildSettings = {
|
||||
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
|
||||
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "compiled.mach-o.executable";
|
||||
Name = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.package-type.mach-o-objfile";
|
||||
Type = "PackageType";
|
||||
Name = "Mach-O Object File";
|
||||
DefaultBuildSettings = {
|
||||
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
|
||||
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "compiled.mach-o.objfile";
|
||||
Name = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.package-type.mach-o-dylib";
|
||||
Type = "PackageType";
|
||||
Name = "Mach-O Dynamic Library";
|
||||
DefaultBuildSettings = {
|
||||
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
|
||||
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "compiled.mach-o.dylib";
|
||||
Name = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.package-type.static-library";
|
||||
Type = "PackageType";
|
||||
Name = "Mach-O Static Library";
|
||||
DefaultBuildSettings = {
|
||||
EXECUTABLE_PREFIX = "lib";
|
||||
EXECUTABLE_SUFFIX = ".a";
|
||||
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
|
||||
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "archive.ar";
|
||||
Name = "$(EXECUTABLE_NAME)";
|
||||
IsLaunchable = "NO";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.package-type.wrapper";
|
||||
Type = "PackageType";
|
||||
Name = "Wrapper";
|
||||
DefaultBuildSettings = {
|
||||
WRAPPER_SUFFIX = ".bundle";
|
||||
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
|
||||
CONTENTS_FOLDER_PATH = "$(WRAPPER_NAME)/Contents";
|
||||
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
|
||||
EXECUTABLE_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/MacOS";
|
||||
EXECUTABLE_PATH = "$(EXECUTABLE_FOLDER_PATH)/$(EXECUTABLE_NAME)";
|
||||
INFOPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/Info.plist";
|
||||
INFOSTRINGS_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/InfoPlist.strings";
|
||||
PKGINFO_PATH = "$(CONTENTS_FOLDER_PATH)/PkgInfo";
|
||||
PBDEVELOPMENTPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/pbdevelopment.plist";
|
||||
VERSIONPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/version.plist";
|
||||
PUBLIC_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Headers";
|
||||
PRIVATE_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PrivateHeaders";
|
||||
EXECUTABLES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Executables";
|
||||
FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Frameworks";
|
||||
SHARED_FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedFrameworks";
|
||||
SHARED_SUPPORT_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedSupport";
|
||||
UNLOCALIZED_RESOURCES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Resources";
|
||||
LOCALIZED_RESOURCES_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/$(DEVELOPMENT_LANGUAGE).lproj";
|
||||
DOCUMENTATION_FOLDER_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/Documentation";
|
||||
PLUGINS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PlugIns";
|
||||
SCRIPTS_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/Scripts";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "wrapper.cfbundle";
|
||||
Name = "$(WRAPPER_NAME)";
|
||||
IsLaunchable = "NO";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.package-type.wrapper.application";
|
||||
Type = "PackageType";
|
||||
BasedOn = "com.apple.package-type.wrapper";
|
||||
Name = "Application Wrapper";
|
||||
DefaultBuildSettings = {
|
||||
GENERATE_PKGINFO_FILE = "YES";
|
||||
};
|
||||
ProductReference = {
|
||||
FileType = "wrapper.application";
|
||||
Name = "$(WRAPPER_NAME)";
|
||||
IsLaunchable = "YES";
|
||||
};
|
||||
}
|
||||
];
|
||||
|
||||
# Based off of the "MacOSX Product Types.xcspec" file. All
|
||||
# bundles/wrapper are removed, because we prefer dynamic products in
|
||||
# NixPkgs.
|
||||
ProductTypes = [
|
||||
{
|
||||
Identifier = "com.apple.product-type.tool";
|
||||
Type = "ProductType";
|
||||
Name = "Command-line Tool";
|
||||
PackageTypes = [ "com.apple.package-type.mach-o-executable" ];
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.product-type.objfile";
|
||||
Type = "ProductType";
|
||||
Name = "Object File";
|
||||
PackageTypes = [ "com.apple.package-type.mach-o-objfile" ];
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.product-type.library.dynamic";
|
||||
Type = "ProductType";
|
||||
Name = "Dynamic Library";
|
||||
PackageTypes = [ "com.apple.package-type.mach-o-dylib" ];
|
||||
DefaultBuildProperties = {
|
||||
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
|
||||
MACH_O_TYPE = "mh_dylib";
|
||||
REZ_EXECUTABLE = "YES";
|
||||
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
|
||||
EXECUTABLE_EXTENSION = "dylib";
|
||||
DYLIB_COMPATIBILITY_VERSION = "1";
|
||||
DYLIB_CURRENT_VERSION = "1";
|
||||
FRAMEWORK_FLAG_PREFIX = "-framework";
|
||||
LIBRARY_FLAG_PREFIX = "-l";
|
||||
LIBRARY_FLAG_NOSPACE = "YES";
|
||||
STRIP_STYLE = "debugging";
|
||||
GCC_INLINES_ARE_PRIVATE_EXTERN = "YES";
|
||||
CODE_SIGNING_ALLOWED = "YES";
|
||||
CODE_SIGNING_REQUIRED = "NO";
|
||||
};
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.product-type.library.static";
|
||||
Type = "ProductType";
|
||||
Name = "Static Library";
|
||||
PackageTypes = [ "com.apple.package-type.static-library" ];
|
||||
DefaultBuildProperties = {
|
||||
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
|
||||
MACH_O_TYPE = "staticlib";
|
||||
REZ_EXECUTABLE = "YES";
|
||||
EXECUTABLE_PREFIX = "lib";
|
||||
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
|
||||
EXECUTABLE_EXTENSION = "a";
|
||||
FRAMEWORK_FLAG_PREFIX = "-framework";
|
||||
LIBRARY_FLAG_PREFIX = "-l";
|
||||
LIBRARY_FLAG_NOSPACE = "YES";
|
||||
STRIP_STYLE = "debugging";
|
||||
SEPARATE_STRIP = "YES";
|
||||
CLANG_ENABLE_MODULE_DEBUGGING = "NO";
|
||||
};
|
||||
}
|
||||
{
|
||||
Type = "ProductType";
|
||||
Identifier = "com.apple.product-type.bundle";
|
||||
Name = "Bundle";
|
||||
DefaultBuildProperties = {
|
||||
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
|
||||
MACH_O_TYPE = "mh_bundle";
|
||||
WRAPPER_PREFIX = "";
|
||||
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
|
||||
WRAPPER_EXTENSION = "bundle";
|
||||
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
|
||||
FRAMEWORK_FLAG_PREFIX = "-framework";
|
||||
LIBRARY_FLAG_PREFIX = "-l";
|
||||
LIBRARY_FLAG_NOSPACE = "YES";
|
||||
STRIP_STYLE = "non-global";
|
||||
};
|
||||
PackageTypes = [ "com.apple.package-type.wrapper" ];
|
||||
IsWrapper = "YES";
|
||||
HasInfoPlist = "YES";
|
||||
HasInfoPlistStrings = "YES";
|
||||
}
|
||||
{
|
||||
Identifier = "com.apple.product-type.application";
|
||||
Type = "ProductType";
|
||||
BasedOn = "com.apple.product-type.bundle";
|
||||
Name = "Application";
|
||||
DefaultBuildProperties = {
|
||||
MACH_O_TYPE = "mh_execute";
|
||||
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
|
||||
WRAPPER_EXTENSION = "app";
|
||||
};
|
||||
PackageTypes = [ "com.apple.package-type.wrapper.application" ];
|
||||
}
|
||||
{
|
||||
Type = "ProductType";
|
||||
Identifier = "com.apple.product-type.framework";
|
||||
Name = "Bundle";
|
||||
DefaultBuildProperties = {
|
||||
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
|
||||
MACH_O_TYPE = "mh_bundle";
|
||||
WRAPPER_PREFIX = "";
|
||||
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
|
||||
WRAPPER_EXTENSION = "bundle";
|
||||
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
|
||||
FRAMEWORK_FLAG_PREFIX = "-framework";
|
||||
LIBRARY_FLAG_PREFIX = "-l";
|
||||
LIBRARY_FLAG_NOSPACE = "YES";
|
||||
STRIP_STYLE = "non-global";
|
||||
};
|
||||
PackageTypes = [ "com.apple.package-type.wrapper" ];
|
||||
IsWrapper = "YES";
|
||||
HasInfoPlist = "YES";
|
||||
HasInfoPlistStrings = "YES";
|
||||
}
|
||||
];
|
||||
|
||||
ToolchainInfo = {
|
||||
Identifier = "com.apple.dt.toolchain.XcodeDefault";
|
||||
};
|
||||
in
|
||||
{
|
||||
"Info.plist" = builtins.toFile "Info.plist" (toPlist { escape = true; } Info);
|
||||
"ToolchainInfo.plist" = builtins.toFile "ToolchainInfo.plist" (
|
||||
toPlist { escape = true; } ToolchainInfo
|
||||
);
|
||||
"Architectures.xcspec" = builtins.toFile "Architectures.xcspec" (
|
||||
toPlist { escape = true; } Architectures
|
||||
);
|
||||
"PackageTypes.xcspec" = builtins.toFile "PackageTypes.xcspec" (
|
||||
toPlist { escape = true; } PackageTypes
|
||||
);
|
||||
"ProductTypes.xcspec" = builtins.toFile "ProductTypes.xcspec" (
|
||||
toPlist { escape = true; } ProductTypes
|
||||
);
|
||||
}
|
||||
40
nix/apple-sdk/common/process-stubs.nix
Normal file
40
nix/apple-sdk/common/process-stubs.nix
Normal file
@@ -0,0 +1,40 @@
|
||||
let
|
||||
removedDylibs = [
|
||||
# corecrypto is available under a very restrictive license (effectively: non-free, can’t use).
|
||||
# Without the headers and not being able to use corecrypto due to its license, it’s not very useful.
|
||||
# Stubs are included in the SDK for all dylibs, including corecrypto. They should be removed.
|
||||
"/usr/lib/system/libcorecrypto.dylib"
|
||||
];
|
||||
in
|
||||
|
||||
{ lib
|
||||
, jq
|
||||
, llvm
|
||||
,
|
||||
}:
|
||||
|
||||
self: super: {
|
||||
nativeBuildInputs = super.nativeBuildInputs or [ ] ++ [
|
||||
jq
|
||||
llvm
|
||||
];
|
||||
|
||||
buildPhase = super.buildPhase or "" + ''
|
||||
echo "Removing the following dylibs from the libSystem reexported libraries list: ${lib.escapeShellArg (lib.concatStringsSep ", " removedDylibs)}"
|
||||
for libSystem in libSystem.B.tbd libSystem.B_asan.tbd; do
|
||||
# tbd-v5 is a JSON-based format, which can be manipulated by `jq`.
|
||||
llvm-readtapi --filetype=tbd-v5 usr/lib/$libSystem \
|
||||
| jq --argjson libs ${lib.escapeShellArg (builtins.toJSON removedDylibs)} '
|
||||
if .libraries then
|
||||
.libraries[] |= select(.install_names[] | any([.] | inside($libs)) | not)
|
||||
else
|
||||
.
|
||||
end
|
||||
| .main_library.reexported_libraries[].names[] |= select([.] | inside($libs) | not)
|
||||
' > usr/lib/$libSystem~
|
||||
# Convert libSystem back to tbd-v4 because not all tooling supports the JSON-based format yet.
|
||||
llvm-readtapi --filetype=tbd-v4 usr/lib/$libSystem~ -o usr/lib/$libSystem
|
||||
rm usr/lib/$libSystem~
|
||||
done
|
||||
'';
|
||||
}
|
||||
74
nix/apple-sdk/common/propagate-inputs.nix
Normal file
74
nix/apple-sdk/common/propagate-inputs.nix
Normal file
@@ -0,0 +1,74 @@
|
||||
{ lib
|
||||
, cups
|
||||
, darwin
|
||||
, db
|
||||
, libiconv
|
||||
, ncurses
|
||||
, stdenv
|
||||
, stdenvNoCC
|
||||
, xcbuild
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
# CUPS has too many dependencies to build as part of the Darwin bootstrap. It’s also typically taken as an explicit
|
||||
# dependency by other packages, so building only the headers (to satisfy other SDK headers) should be okay.
|
||||
cupsHeaders = darwin.bootstrapStdenv.mkDerivation {
|
||||
pname = "${lib.getName cups}-headers";
|
||||
version = lib.getVersion cups;
|
||||
|
||||
inherit (cups) src;
|
||||
|
||||
patches = cups.patches or [ ];
|
||||
|
||||
strictDeps = true;
|
||||
|
||||
dontBuild = true;
|
||||
|
||||
buildInputs = [ darwin.libresolv ]; # The `configure` script requires libresolv headers.
|
||||
|
||||
# CUPS’s configure script fails to find `ar` when cross-compiling.
|
||||
configureFlags = [ "ac_cv_path_AR=${stdenv.cc.targetPrefix}ar" ];
|
||||
|
||||
installTargets = [ "install-headers" ];
|
||||
|
||||
__structuredAttrs = true;
|
||||
|
||||
meta = {
|
||||
inherit (cups.meta)
|
||||
homepage
|
||||
description
|
||||
license
|
||||
maintainers
|
||||
platforms
|
||||
;
|
||||
};
|
||||
};
|
||||
in
|
||||
self: super: {
|
||||
# These packages are propagated only because other platforms include them in their libc (or otherwise by default).
|
||||
# Reducing the number of special cases required to support Darwin makes supporting it easier for package authors.
|
||||
propagatedBuildInputs =
|
||||
super.propagatedBuildInputs or [ ]
|
||||
++ [
|
||||
libiconv
|
||||
darwin.libresolv
|
||||
darwin.libsbuf
|
||||
# Shipped with the SDK only as a library with no headers
|
||||
(lib.getLib darwin.libutil)
|
||||
]
|
||||
# x86_64-darwin links the object files from Csu when targeting very old releases
|
||||
++ lib.optionals stdenvNoCC.hostPlatform.isx86_64 [ darwin.Csu ];
|
||||
|
||||
# The Darwin module for Swift requires certain headers to be included in the SDK (and not just be propagated).
|
||||
buildPhase = super.buildPhase or "" + ''
|
||||
for header in '${lib.getDev libiconv}/include/'* '${lib.getDev ncurses}/include/'* '${cupsHeaders}/include/'*; do
|
||||
ln -s "$header" "usr/include/$(basename "$header")"
|
||||
done
|
||||
'';
|
||||
|
||||
# Exported to allow the headers to pass the requisites check in the stdenv bootstrap.
|
||||
passthru = (super.passthru or { }) // {
|
||||
cups-headers = cupsHeaders;
|
||||
};
|
||||
}
|
||||
53
nix/apple-sdk/common/propagate-xcrun.nix
Normal file
53
nix/apple-sdk/common/propagate-xcrun.nix
Normal file
@@ -0,0 +1,53 @@
|
||||
{ lib
|
||||
, pkgsBuildHost
|
||||
, stdenv
|
||||
, stdenvNoCC
|
||||
, sdkVersion
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
plists = import ./plists.nix {
|
||||
inherit lib stdenvNoCC sdkVersion;
|
||||
xcodePlatform = if stdenvNoCC.hostPlatform.isMacOS then "MacOSX" else "iPhoneOS";
|
||||
};
|
||||
inherit (pkgsBuildHost) darwin cctools xcbuild;
|
||||
in
|
||||
self: super: {
|
||||
propagatedNativeBuildInputs = super.propagatedNativeBuildInputs or [ ] ++ [ xcbuild.xcrun ];
|
||||
|
||||
postInstall = super.postInstall or "" + ''
|
||||
specspath=$out/Library/Xcode/Specifications
|
||||
toolchainsPath=$out/Toolchains/XcodeDefault.xctoolchain
|
||||
mkdir -p "$specspath" "$toolchainsPath"
|
||||
|
||||
# xcbuild expects to find things relative to the plist locations. If these are linked instead of copied,
|
||||
# it won’t find any platforms or SDKs.
|
||||
cp '${plists."Info.plist"}' "$platformPath/Info.plist"
|
||||
cp '${plists."ToolchainInfo.plist"}' "$toolchainsPath/ToolchainInfo.plist"
|
||||
|
||||
for spec in '${xcbuild}/Library/Xcode/Specifications/'*; do
|
||||
ln -s "$spec" "$specspath/$(basename "$spec")"
|
||||
done
|
||||
cp '${plists."Architectures.xcspec"}' "$specspath/Architectures.xcspec"
|
||||
cp '${plists."PackageTypes.xcspec"}' "$specspath/PackageTypes.xcspec"
|
||||
cp '${plists."ProductTypes.xcspec"}' "$specspath/ProductTypes.xcspec"
|
||||
|
||||
mkdir -p "$out/usr/bin"
|
||||
ln -s '${xcbuild.xcrun}/bin/xcrun' "$out/usr/bin/xcrun"
|
||||
|
||||
# Include `libtool` in the toolchain, so `xcrun -find libtool` can find it without requiring `cctools.libtool`
|
||||
# as a `nativeBuildInput`.
|
||||
mkdir -p "$toolchainsPath/usr/bin"
|
||||
if [ -e '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' ]; then
|
||||
ln -s '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' "$toolchainsPath/usr/bin/libtool"
|
||||
fi
|
||||
|
||||
# Include additional binutils required by some packages (such as Chromium).
|
||||
for tool in lipo nm otool size strip; do
|
||||
if [ -e '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool ]; then
|
||||
ln -s '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool "$toolchainsPath/usr/bin/$tool"
|
||||
fi
|
||||
done
|
||||
'';
|
||||
}
|
||||
24
nix/apple-sdk/common/remove-disallowed-packages.nix
Normal file
24
nix/apple-sdk/common/remove-disallowed-packages.nix
Normal file
@@ -0,0 +1,24 @@
|
||||
let
|
||||
disallowedPackages = builtins.fromJSON (builtins.readFile ../metadata/disallowed-packages.json);
|
||||
in
|
||||
|
||||
{ lib
|
||||
, jq
|
||||
, stdenv
|
||||
,
|
||||
}:
|
||||
|
||||
self: super: {
|
||||
# Remove headers and stubs for packages that are available in nixpkgs.
|
||||
buildPhase = super.buildPhase or "" + ''
|
||||
${lib.concatMapStringsSep "\n" (
|
||||
pkg:
|
||||
lib.concatLines (
|
||||
[ ''echo "Removing headers and libraries from ${pkg.package}"'' ]
|
||||
++ (map (header: "rm -rf -- usr/include/${header}") pkg.headers or [ ])
|
||||
++ (map (framework: "rm -rf -- System/Library/Frameworks/${framework}") pkg.frameworks or [ ])
|
||||
++ (map (library: "rm -rf -- usr/lib/${library}") pkg.libraries or [ ])
|
||||
)
|
||||
) disallowedPackages}
|
||||
'';
|
||||
}
|
||||
9
nix/apple-sdk/common/run-build-phase-hooks.nix
Normal file
9
nix/apple-sdk/common/run-build-phase-hooks.nix
Normal file
@@ -0,0 +1,9 @@
|
||||
{}:
|
||||
|
||||
self: super: {
|
||||
buildPhase = ''
|
||||
runHook preBuild
|
||||
${super.buildPhase or ""}
|
||||
runHook postBuild
|
||||
'';
|
||||
}
|
||||
536
nix/apple-sdk/metadata/apple-oss-lockfile.json
Normal file
536
nix/apple-sdk/metadata/apple-oss-lockfile.json
Normal file
@@ -0,0 +1,536 @@
|
||||
{
|
||||
"14.4": {
|
||||
"CarbonHeaders": {
|
||||
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
|
||||
"version": "18.1"
|
||||
},
|
||||
"CommonCrypto": {
|
||||
"hash": "sha256-/VoOR9wJuKnmGE1CWGGXxX8SpmALHnEooNTa3QM+ITc=",
|
||||
"version": "600028.100.1"
|
||||
},
|
||||
"IOAudioFamily": {
|
||||
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
|
||||
"version": "540.3"
|
||||
},
|
||||
"IOBDStorageFamily": {
|
||||
"hash": "sha256-UgLMsQBe1QLzlbScmPmASBN7VH4YBmNOUX2CEDezjmE=",
|
||||
"version": "22"
|
||||
},
|
||||
"IOCDStorageFamily": {
|
||||
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
|
||||
"version": "61"
|
||||
},
|
||||
"IODVDStorageFamily": {
|
||||
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
|
||||
"version": "45"
|
||||
},
|
||||
"IOFWDVComponents": {
|
||||
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
|
||||
"version": "208"
|
||||
},
|
||||
"IOFireWireAVC": {
|
||||
"hash": "sha256-IUytBKhhCgg0vtI+7q8d5kxpOUgO3tQD7TMy++jrorc=",
|
||||
"version": "431"
|
||||
},
|
||||
"IOFireWireFamily": {
|
||||
"hash": "sha256-W0KOF4hkA7kFOnL1ThAeFU/YlhFVqoqk9uzGjcBppX8=",
|
||||
"version": "487"
|
||||
},
|
||||
"IOFireWireSBP2": {
|
||||
"hash": "sha256-bItnRQIaGUxMyiU0q+4N8e5+jYiDEOUPmsrKhBFXvok=",
|
||||
"version": "445"
|
||||
},
|
||||
"IOFireWireSerialBusProtocolTransport": {
|
||||
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
|
||||
"version": "260"
|
||||
},
|
||||
"IOGraphics": {
|
||||
"hash": "sha256-Ag37fd3tZJLXLVq1yzHOCWGOYYfwwTkC8hnvNaTEaWg=",
|
||||
"version": "598"
|
||||
},
|
||||
"IOHIDFamily": {
|
||||
"hash": "sha256-fmYTJsquAOBwzsgRmqPyjSJJi1hGcfnMmqLIcTe8W1s=",
|
||||
"version": "2031.100.16"
|
||||
},
|
||||
"IOKitUser": {
|
||||
"hash": "sha256-1bqRiLvyr2GQfbWwhXHXXIOtIka9YDw5GbKV6bd2k4k=",
|
||||
"version": "100076.101.1"
|
||||
},
|
||||
"IONetworkingFamily": {
|
||||
"hash": "sha256-J3cLeWKrQ8ypIaqgwRH9eU5JbjEDBVoezj3a2Lvwu5k=",
|
||||
"version": "177"
|
||||
},
|
||||
"IOSerialFamily": {
|
||||
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
|
||||
"version": "93"
|
||||
},
|
||||
"IOStorageFamily": {
|
||||
"hash": "sha256-cllpJX11c3CX8zEYdOT2TC63sx7NUAHh33yRHhrG2Ro=",
|
||||
"version": "315"
|
||||
},
|
||||
"IOUSBFamily": {
|
||||
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
|
||||
"version": "630.4.5"
|
||||
},
|
||||
"Libc": {
|
||||
"hash": "sha256-fxBM4KbPwQNVEJl7PCKP+1nUk9Oce/O2+0lVBxyngew=",
|
||||
"version": "1592.100.35"
|
||||
},
|
||||
"Libinfo": {
|
||||
"hash": "sha256-zZr6Mmou8Q+G6/wS+k0k7R+XirB94TNCUGS5dhi96ZE=",
|
||||
"version": "583.0.1"
|
||||
},
|
||||
"Libm": {
|
||||
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
|
||||
"version": "2026"
|
||||
},
|
||||
"Libnotify": {
|
||||
"hash": "sha256-7X+6S3C7ZOTXJUeDXOOg5EmoZyLZvtE06x3Is0TGgSU=",
|
||||
"version": "317.100.2"
|
||||
},
|
||||
"Librpcsvc": {
|
||||
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
|
||||
"version": "31"
|
||||
},
|
||||
"Libsystem": {
|
||||
"hash": "sha256-HsItciWrwyXujQ2hwqzv0JKOkkuynXYIqejLAEPJbMc=",
|
||||
"version": "1345.100.2"
|
||||
},
|
||||
"OpenDirectory": {
|
||||
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
|
||||
"version": "146"
|
||||
},
|
||||
"Security": {
|
||||
"hash": "sha256-NgTGbaw5JkpboDQpt1fSgUr9NYGS+bIOrEMQX7mLAME=",
|
||||
"version": "61123.100.169"
|
||||
},
|
||||
"architecture": {
|
||||
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
|
||||
"version": "282"
|
||||
},
|
||||
"configd": {
|
||||
"hash": "sha256-+3xesYxqfsNjWCW3T87OA7+Z1hBqmGEh/I8kP8Ajbso=",
|
||||
"version": "1300.100.9"
|
||||
},
|
||||
"copyfile": {
|
||||
"hash": "sha256-rSCTgzdHr7QmnPk9rJ9P4fOAolnEQv8PHfgAY+qA0s4=",
|
||||
"version": "196.100.4"
|
||||
},
|
||||
"dtrace": {
|
||||
"hash": "sha256-04Q35rCKnM5Csv5poFJKpK0VplWq4hvy251/Cb2Kl80=",
|
||||
"version": "401.100.3"
|
||||
},
|
||||
"dyld": {
|
||||
"hash": "sha256-6P/Da6xP19vmaCROoYv9pl7DaW3/U+qZBJT8PD33bn0=",
|
||||
"version": "1160.6"
|
||||
},
|
||||
"eap8021x": {
|
||||
"hash": "sha256-Ky6KSlJhyX1NRufGhVBcp+ZFmqYrAxwC/5QvJhC2PhU=",
|
||||
"version": "354.100.3"
|
||||
},
|
||||
"hfs": {
|
||||
"hash": "sha256-+YUVOttZU7C8I14CC6t3ZH2KxAjjTA2nB0y5bPgLxZM=",
|
||||
"version": "650.0.2"
|
||||
},
|
||||
"launchd": {
|
||||
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
|
||||
"version": "842.1.4"
|
||||
},
|
||||
"libclosure": {
|
||||
"hash": "sha256-M/jnIHzKYvdFCO0tJ1JXiD/UcZtJhLIoulaCQQUbn30=",
|
||||
"version": "90"
|
||||
},
|
||||
"libdispatch": {
|
||||
"hash": "sha256-igqIA5DMVHjG30WMHZZpYY7LRM9hZyMWItD+UxeTehY=",
|
||||
"version": "1477.100.9"
|
||||
},
|
||||
"libmalloc": {
|
||||
"hash": "sha256-Sh4/z7lGWRMldOPURkP5vLOAb5Ou6AUsVJEWz9wk9hI=",
|
||||
"version": "521.100.59"
|
||||
},
|
||||
"libplatform": {
|
||||
"hash": "sha256-gojt3sWOr7XO2yYI/B1CmNLTPFieSfoNtlOgQahOCok=",
|
||||
"version": "316.100.10"
|
||||
},
|
||||
"libpthread": {
|
||||
"hash": "sha256-phjfN8+IU8ibPsflR6LktnSi3giy89ghI+cFyrhiQNo=",
|
||||
"version": "519.101.1"
|
||||
},
|
||||
"mDNSResponder": {
|
||||
"hash": "sha256-0ECbWeMnIRTsi03BeBEe5boyR/84JJPbxzPQze8hHSA=",
|
||||
"version": "2200.100.94.0.2"
|
||||
},
|
||||
"objc4": {
|
||||
"hash": "sha256-eUVSpbyTEOMEdHoxSv6lZIZwB+cW/YWIaTZTcHgGOjo=",
|
||||
"version": "912.3"
|
||||
},
|
||||
"ppp": {
|
||||
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
|
||||
"version": "1016"
|
||||
},
|
||||
"removefile": {
|
||||
"hash": "sha256-L6I0u8S3h3uV1veKA5HvkSebbBCd78ymlf//KWbebZo=",
|
||||
"version": "70.100.4"
|
||||
},
|
||||
"xnu": {
|
||||
"hash": "sha256-j5Ep1RX5DTJqTGszrF4d/JtzUqZ6nA6XoExqcIQ0RVQ=",
|
||||
"version": "10063.101.15"
|
||||
}
|
||||
},
|
||||
"15.5": {
|
||||
"CarbonHeaders": {
|
||||
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
|
||||
"version": "18.1"
|
||||
},
|
||||
"CommonCrypto": {
|
||||
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
|
||||
"version": "600035"
|
||||
},
|
||||
"IOAudioFamily": {
|
||||
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
|
||||
"version": "600.2"
|
||||
},
|
||||
"IOBDStorageFamily": {
|
||||
"hash": "sha256-s8hTwX0jq2iPULfBLUwpzqtszWuvJrrLGbmrKa/fY4U=",
|
||||
"version": "24"
|
||||
},
|
||||
"IOCDStorageFamily": {
|
||||
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
|
||||
"version": "62"
|
||||
},
|
||||
"IODVDStorageFamily": {
|
||||
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
|
||||
"version": "46"
|
||||
},
|
||||
"IOFWDVComponents": {
|
||||
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
|
||||
"version": "208"
|
||||
},
|
||||
"IOFireWireAVC": {
|
||||
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
|
||||
"version": "434"
|
||||
},
|
||||
"IOFireWireFamily": {
|
||||
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
|
||||
"version": "490"
|
||||
},
|
||||
"IOFireWireSBP2": {
|
||||
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
|
||||
"version": "452"
|
||||
},
|
||||
"IOFireWireSerialBusProtocolTransport": {
|
||||
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
|
||||
"version": "261"
|
||||
},
|
||||
"IOGraphics": {
|
||||
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
|
||||
"version": "599"
|
||||
},
|
||||
"IOHIDFamily": {
|
||||
"hash": "sha256-gEYPyjXgQ2ABGufCKPjmzMdNRLxhELkCvOURCokyTO4=",
|
||||
"version": "2115.100.21"
|
||||
},
|
||||
"IOKitUser": {
|
||||
"hash": "sha256-p32U+jHfwA/tqnjF4p1BmojghEXK8KxiflW3IHs2iIY=",
|
||||
"version": "100150.120.2"
|
||||
},
|
||||
"IONetworkingFamily": {
|
||||
"hash": "sha256-gZ7Dkk4Iu7AV9K2ioqSeJ1W7bTNxv77bmT18iv3ljLg=",
|
||||
"version": "185"
|
||||
},
|
||||
"IOSerialFamily": {
|
||||
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
|
||||
"version": "93"
|
||||
},
|
||||
"IOStorageFamily": {
|
||||
"hash": "sha256-/0H0tqWUWkgYigYypucbc7lOCFYDuukwF9fvLEOhwOk=",
|
||||
"version": "323"
|
||||
},
|
||||
"IOUSBFamily": {
|
||||
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
|
||||
"version": "630.4.5"
|
||||
},
|
||||
"Libc": {
|
||||
"hash": "sha256-nWDokN0Vr5pUyNGculnDOah9RNgHiWr3S13RSQLmZrc=",
|
||||
"version": "1698.100.8"
|
||||
},
|
||||
"Libinfo": {
|
||||
"hash": "sha256-UI5mGvzZ6BPafGYD6CrNAJAKjeJLB6urAS2lpB6X/Ec=",
|
||||
"version": "597"
|
||||
},
|
||||
"Libm": {
|
||||
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
|
||||
"version": "2026"
|
||||
},
|
||||
"Libnotify": {
|
||||
"hash": "sha256-GDYMVi1034f9empq0YOuumQp/BDJ7phTb0Zl4KTY9xg=",
|
||||
"version": "342"
|
||||
},
|
||||
"Librpcsvc": {
|
||||
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
|
||||
"version": "31"
|
||||
},
|
||||
"Libsystem": {
|
||||
"hash": "sha256-nawWJiu2IJ34ek5iOX6CrlqMzev7TuJpUkvDp30ZQ/U=",
|
||||
"version": "1351"
|
||||
},
|
||||
"OpenDirectory": {
|
||||
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
|
||||
"version": "146"
|
||||
},
|
||||
"Security": {
|
||||
"hash": "sha256-ZOrOOCk+hZbzDilzkihpQfsDpzV3Ul4zy6fpFRWUQHw=",
|
||||
"version": "61439.120.27"
|
||||
},
|
||||
"architecture": {
|
||||
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
|
||||
"version": "282"
|
||||
},
|
||||
"configd": {
|
||||
"hash": "sha256-ZdUq1SrOwB88Lx68ekrA4zeVsLDZz4TAJywNnF+uAzY=",
|
||||
"version": "1351.120.3"
|
||||
},
|
||||
"copyfile": {
|
||||
"hash": "sha256-rLqT6e44W2ohgwUXREmiOyJBYCrV3gRLbtVnbUq60xc=",
|
||||
"version": "221.121.1"
|
||||
},
|
||||
"dtrace": {
|
||||
"hash": "sha256-iNEZyxK3DmEwO3gzrfvCaVZSEuuOMQm5IG/6FodPNdI=",
|
||||
"version": "411"
|
||||
},
|
||||
"dyld": {
|
||||
"hash": "sha256-4OOghgUYyMJbsTe96fiWCndTJ1BS94rK9v6Kqn/ooYs=",
|
||||
"version": "1285.19"
|
||||
},
|
||||
"eap8021x": {
|
||||
"hash": "sha256-Kx/wwnt108hDm0qQPyTNbZ8KoHkD5m7L4yb5qjSuQjI=",
|
||||
"version": "365.120.2"
|
||||
},
|
||||
"hfs": {
|
||||
"hash": "sha256-5/3Ycp3cKqlgAl1kjBmbF5tFlfJYQS5rbrbk4SS66b8=",
|
||||
"version": "683.120.3"
|
||||
},
|
||||
"launchd": {
|
||||
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
|
||||
"version": "842.1.4"
|
||||
},
|
||||
"libclosure": {
|
||||
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
|
||||
"version": "96"
|
||||
},
|
||||
"libdispatch": {
|
||||
"hash": "sha256-jTp2DolOOCQPBt1HRotkmPnKgQ2LGgniEqeHoM+vlKg=",
|
||||
"version": "1521.120.4"
|
||||
},
|
||||
"libmalloc": {
|
||||
"hash": "sha256-d9AVHSYTqHDlgctv8Hh4HAYW53MJelj4F8LWPsjrsws=",
|
||||
"version": "715.120.13"
|
||||
},
|
||||
"libplatform": {
|
||||
"hash": "sha256-gpijoTMvdkM0PdG8gyIllOJlh/MtTc4ro9ODDAhN6gM=",
|
||||
"version": "349"
|
||||
},
|
||||
"libpthread": {
|
||||
"hash": "sha256-N+MMXdbthsxauTTfZ5ElUs39dVH+Chn1yyU6pObZpkU=",
|
||||
"version": "536"
|
||||
},
|
||||
"mDNSResponder": {
|
||||
"hash": "sha256-ILx12PRxj/+VqfpCCErJFEJXFI9yzTh4g+FK0UCenIE=",
|
||||
"version": "2600.120.12"
|
||||
},
|
||||
"objc4": {
|
||||
"hash": "sha256-DMxa25gXjKCkiDnVJ/8SyJUjaBlmBGABg8EfCHcmTj0=",
|
||||
"version": "940.4"
|
||||
},
|
||||
"ppp": {
|
||||
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
|
||||
"version": "1016"
|
||||
},
|
||||
"removefile": {
|
||||
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
|
||||
"version": "81"
|
||||
},
|
||||
"xnu": {
|
||||
"hash": "sha256-o4tCuCAIgAYg/Li3wTs12mVWr5C/4vbwu1zi+kJ9d6w=",
|
||||
"version": "11417.121.6"
|
||||
}
|
||||
},
|
||||
"26.0": {
|
||||
"CarbonHeaders": {
|
||||
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
|
||||
"version": "18.1"
|
||||
},
|
||||
"CommonCrypto": {
|
||||
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
|
||||
"version": "600035"
|
||||
},
|
||||
"IOAudioFamily": {
|
||||
"hash": "sha256-A3iiAjjP29VdjMj40tLS5Q/ni4qeh9bBpnmNzeG2pIY=",
|
||||
"version": "700.2"
|
||||
},
|
||||
"IOBDStorageFamily": {
|
||||
"hash": "sha256-OcQUJ3nEfrpvWX/npnedJ4PECIGWFSLiM0PKoiH911w=",
|
||||
"version": "26"
|
||||
},
|
||||
"IOCDStorageFamily": {
|
||||
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
|
||||
"version": "62"
|
||||
},
|
||||
"IODVDStorageFamily": {
|
||||
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
|
||||
"version": "46"
|
||||
},
|
||||
"IOFWDVComponents": {
|
||||
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
|
||||
"version": "208"
|
||||
},
|
||||
"IOFireWireAVC": {
|
||||
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
|
||||
"version": "436"
|
||||
},
|
||||
"IOFireWireFamily": {
|
||||
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
|
||||
"version": "492"
|
||||
},
|
||||
"IOFireWireSBP2": {
|
||||
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
|
||||
"version": "454"
|
||||
},
|
||||
"IOFireWireSerialBusProtocolTransport": {
|
||||
"hash": "sha256-cM/VFhVWNVwdJYk+mme0UYttQd7eJwd7Hlo7KNRyHY0=",
|
||||
"version": "262"
|
||||
},
|
||||
"IOGraphics": {
|
||||
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
|
||||
"version": "599"
|
||||
},
|
||||
"IOHIDFamily": {
|
||||
"hash": "sha256-YLnabX90g4Q8LxjwVuJF6KODCDxychWV+VJaNG9d8fI=",
|
||||
"version": "2222.0.24"
|
||||
},
|
||||
"IOKitUser": {
|
||||
"hash": "sha256-ngwi8YMUqE0q8j7Lr5cqJwi2V+IDu3ie3bduotHIUJU=",
|
||||
"version": "100222.0.4"
|
||||
},
|
||||
"IONetworkingFamily": {
|
||||
"hash": "sha256-ZF5ML41Y1l1liQn32qTkcl4mMvx9Xdizb9VgvTzVTL4=",
|
||||
"version": "186"
|
||||
},
|
||||
"IOSerialFamily": {
|
||||
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
|
||||
"version": "93"
|
||||
},
|
||||
"IOStorageFamily": {
|
||||
"hash": "sha256-1FKSF622qeXPGngA3UmQ2M/IU1pdlMoYBPbXytUFDaQ=",
|
||||
"version": "331"
|
||||
},
|
||||
"IOUSBFamily": {
|
||||
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
|
||||
"version": "630.4.5"
|
||||
},
|
||||
"Libc": {
|
||||
"hash": "sha256-k+HQ+qgye0ORFm0hU8WzE4ysbbEoFZ7wcbVl5giDH/E=",
|
||||
"version": "1725.0.11"
|
||||
},
|
||||
"Libinfo": {
|
||||
"hash": "sha256-4InBEPi0n2EMo/8mIBib1Im4iTKRcRJ4IlAcLCigVGk=",
|
||||
"version": "600"
|
||||
},
|
||||
"Libm": {
|
||||
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
|
||||
"version": "2026"
|
||||
},
|
||||
"Libnotify": {
|
||||
"hash": "sha256-p8cJZlBYOFmI1NDHXGYjgcv8z9Ldc1amZuYlxxJfeVY=",
|
||||
"version": "344.0.1"
|
||||
},
|
||||
"Librpcsvc": {
|
||||
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
|
||||
"version": "31"
|
||||
},
|
||||
"Libsystem": {
|
||||
"hash": "sha256-/NlSwPaoTVx+bl9hYsfz3C5MuLdqGv4vdAh0KDbDKmY=",
|
||||
"version": "1356"
|
||||
},
|
||||
"OpenDirectory": {
|
||||
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
|
||||
"version": "146"
|
||||
},
|
||||
"Security": {
|
||||
"hash": "sha256-oxOvZsDoNYZNiWf+MASHrR4Q2o5oaqvK2We51hH7CO8=",
|
||||
"version": "61901.0.87.0.1"
|
||||
},
|
||||
"architecture": {
|
||||
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
|
||||
"version": "282"
|
||||
},
|
||||
"configd": {
|
||||
"hash": "sha256-58or+OQP788UgQKO7Y8k8pY/enaSqH971ks7xCPu8fA=",
|
||||
"version": "1385.0.7"
|
||||
},
|
||||
"copyfile": {
|
||||
"hash": "sha256-I9uDi5BDQKa7mO3XpHxv0d6PiROW2ueZ3vGfrsG0OJo=",
|
||||
"version": "230.0.1.0.1"
|
||||
},
|
||||
"dtrace": {
|
||||
"hash": "sha256-5HpH6Cg8vWWzOX5ADD//izKDvqGnzV05Giju8lmGeyA=",
|
||||
"version": "413"
|
||||
},
|
||||
"dyld": {
|
||||
"hash": "sha256-jzoFLwbms0rUwzyjYif/r6Rmr4kyn+as/bhc4paEPeY=",
|
||||
"version": "1323.3"
|
||||
},
|
||||
"eap8021x": {
|
||||
"hash": "sha256-17bseWT4OWMA8hF+YSDDjxhVyJpbpP2xwv8dGti1YoM=",
|
||||
"version": "368.0.3"
|
||||
},
|
||||
"hfs": {
|
||||
"hash": "sha256-OkgqZ03gwn2hTuHxZrPDmQOrY4Dwu7MrX+BfG+PTgvE=",
|
||||
"version": "704.0.3.0.2"
|
||||
},
|
||||
"launchd": {
|
||||
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
|
||||
"version": "842.1.4"
|
||||
},
|
||||
"libclosure": {
|
||||
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
|
||||
"version": "96"
|
||||
},
|
||||
"libdispatch": {
|
||||
"hash": "sha256-L0+Ho9dAlMXVpqFEGIcIMsJc0gULckRulUImNEZe5MU=",
|
||||
"version": "1542.0.4"
|
||||
},
|
||||
"libmalloc": {
|
||||
"hash": "sha256-482hgm1ESr3LWC/JhuQNGNu9smsa2Eap49/eH+YNAio=",
|
||||
"version": "792.1.1"
|
||||
},
|
||||
"libplatform": {
|
||||
"hash": "sha256-wGZ2Im81mRXx6epgj/tbOJpg89CEbAr0Z8oFEpkyNMU=",
|
||||
"version": "359.1.2"
|
||||
},
|
||||
"libpthread": {
|
||||
"hash": "sha256-VuMpQjxuMsdHsFq0q6QIWSWi88gVF2jNzIfti20Gkbw=",
|
||||
"version": "539"
|
||||
},
|
||||
"mDNSResponder": {
|
||||
"hash": "sha256-iRqCpPAQDRjgRbRz3s6q2oyzq6xo+w4FTBai79104Zo=",
|
||||
"version": "2881.0.25"
|
||||
},
|
||||
"objc4": {
|
||||
"hash": "sha256-Nlgr36yLvGkUJIEFQ5w8FAB0r2syEsRTw0KuUShNT8E=",
|
||||
"version": "950"
|
||||
},
|
||||
"ppp": {
|
||||
"hash": "sha256-FzHZ05o7JxwgTqz0e3D68b/DiLu2x2ErzGMh0U78fLo=",
|
||||
"version": "1020.1.1"
|
||||
},
|
||||
"removefile": {
|
||||
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
|
||||
"version": "84"
|
||||
},
|
||||
"xnu": {
|
||||
"hash": "sha256-Cuf7kPtsn4CPXqyZmxVsJlA5i+Ikryp8ezJyGrvT63c=",
|
||||
"version": "12377.1.9"
|
||||
}
|
||||
}
|
||||
}
|
||||
533
nix/apple-sdk/metadata/disallowed-packages.json
Normal file
533
nix/apple-sdk/metadata/disallowed-packages.json
Normal file
@@ -0,0 +1,533 @@
|
||||
[
|
||||
{
|
||||
"package": "apache",
|
||||
"headers": [
|
||||
"apache2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "apr",
|
||||
"headers": [
|
||||
"apr-1"
|
||||
],
|
||||
"libraries": [
|
||||
"libapr-1.*",
|
||||
"libaprutil-1.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "boringssl",
|
||||
"libraries": [
|
||||
"libboringssl.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "bzip2",
|
||||
"headers": [
|
||||
"bzlib.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libbz2.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "corecrypto",
|
||||
"libraries": [
|
||||
"system/libcorecrypto*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "Csu",
|
||||
"libraries": [
|
||||
"*.o"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "cups",
|
||||
"headers": [
|
||||
"cups"
|
||||
],
|
||||
"libraries": [
|
||||
"libcups*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "curl",
|
||||
"headers": [
|
||||
"curl"
|
||||
],
|
||||
"libraries": [
|
||||
"libcurl.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "cyrus_sasl",
|
||||
"headers": [
|
||||
"sasl"
|
||||
],
|
||||
"libraries": [
|
||||
"libsasl*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "editline",
|
||||
"headers": [
|
||||
"editline.h",
|
||||
"editline"
|
||||
],
|
||||
"libraries": [
|
||||
"libedit.*",
|
||||
"libeditline.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "html-tidy",
|
||||
"headers": [
|
||||
"tidy*"
|
||||
],
|
||||
"libraries": [
|
||||
"libtidy.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "hunspell",
|
||||
"headers": [
|
||||
"hunspell"
|
||||
],
|
||||
"libraries": [
|
||||
"libhunspell*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "icu",
|
||||
"headers": [
|
||||
"unicode"
|
||||
],
|
||||
"libraries": [
|
||||
"libicucore.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libarchive",
|
||||
"headers": [
|
||||
"archive.h",
|
||||
"archive_entry.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libarchive.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libc++",
|
||||
"headers": [
|
||||
"c++",
|
||||
"cxxabi.h",
|
||||
"__cxxabi_config.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libc++*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "ld64",
|
||||
"libraries": [
|
||||
"libcodedirectory.*",
|
||||
"libcodedirectory_static.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "expat",
|
||||
"headers": [
|
||||
"expat.h",
|
||||
"expat_config.h",
|
||||
"expat_external.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libexpat.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libffi",
|
||||
"headers": [
|
||||
"ffi*"
|
||||
],
|
||||
"libraries": [
|
||||
"libffi*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libgcc",
|
||||
"libraries": [
|
||||
"libgcc*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libiconv",
|
||||
"headers": [
|
||||
"iconv.h",
|
||||
"libcharset.h",
|
||||
"localcharset.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libcharset.*",
|
||||
"libiconv.*",
|
||||
"i18n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libiodbc",
|
||||
"libraries": [
|
||||
"libiodbc*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libkrb4",
|
||||
"libraries": [
|
||||
"libkrb4.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libkrb5",
|
||||
"headers": [
|
||||
"com_err.h",
|
||||
"gssapi",
|
||||
"gssapi.h",
|
||||
"gssrpc",
|
||||
"kadm5",
|
||||
"kdb.h",
|
||||
"krad.h",
|
||||
"krb5",
|
||||
"krb5.h",
|
||||
"profile.h",
|
||||
"verto-module.h",
|
||||
"verto.h"
|
||||
],
|
||||
"libraries": [
|
||||
"krb5",
|
||||
"libcom_err.*",
|
||||
"libgssapi_krb5.*",
|
||||
"libgssrpc.*",
|
||||
"libk5crypto.*",
|
||||
"libkadm5clnt.*",
|
||||
"libkadm5clnt_mit.*",
|
||||
"libkadm5srv.*",
|
||||
"libkadm5srv_mit.*",
|
||||
"libkdb5.*",
|
||||
"libkrad.*",
|
||||
"libkrb5*",
|
||||
"libkrb5support.*",
|
||||
"libverto.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libpcap",
|
||||
"headers": [
|
||||
"pcap*"
|
||||
],
|
||||
"libraries": [
|
||||
"libpcap.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libresolv",
|
||||
"headers": [
|
||||
"arpa/nameser.h",
|
||||
"arpa/nameser_compat.h",
|
||||
"dns.h",
|
||||
"dns_util.h",
|
||||
"nameser.h",
|
||||
"resolv.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libresolv.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libstdc++",
|
||||
"libraries": [
|
||||
"libstdc++.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libsbuf",
|
||||
"headers": [
|
||||
"usbuf.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libsbuf.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libtermcap",
|
||||
"headers": [
|
||||
"termcap.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libtermcap.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libutil",
|
||||
"headers": [
|
||||
"libutil.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libutil.*",
|
||||
"libutil1.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libxml2",
|
||||
"headers": [
|
||||
"libxml",
|
||||
"libxml2"
|
||||
],
|
||||
"libraries": [
|
||||
"libxml2.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libxo",
|
||||
"headers": [
|
||||
"libxo"
|
||||
],
|
||||
"libraries": [
|
||||
"libxo.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "libxslt",
|
||||
"headers": [
|
||||
"libexslt",
|
||||
"libxslt"
|
||||
],
|
||||
"libraries": [
|
||||
"libexslt.*",
|
||||
"libxslt.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "liby",
|
||||
"libraries": [
|
||||
"liby.a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "marisa-trie",
|
||||
"libraries": [
|
||||
"libmarisa.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "ncurses",
|
||||
"headers": [
|
||||
"curses*",
|
||||
"cursslk.h",
|
||||
"eti.h",
|
||||
"etip.h",
|
||||
"form.h",
|
||||
"menu.h",
|
||||
"nc_tparm.h",
|
||||
"ncurses*",
|
||||
"panel.h",
|
||||
"term.h",
|
||||
"term_entry.h",
|
||||
"termcap.h",
|
||||
"tic.h",
|
||||
"unctrl.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libcurses.*",
|
||||
"libform.*",
|
||||
"libformw.*",
|
||||
"libmenu.*",
|
||||
"libmenuw.*",
|
||||
"libncurses.*",
|
||||
"libncursesw.*",
|
||||
"libpanel.*",
|
||||
"libpanelw.*",
|
||||
"libtinfo.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "net-snmp",
|
||||
"headers": [
|
||||
"net-snmp"
|
||||
],
|
||||
"libraries": [
|
||||
"libnetsnmp*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "nghttp",
|
||||
"libraries": [
|
||||
"lib*nghttp2.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "openblas",
|
||||
"headers": [
|
||||
"cblas.h",
|
||||
"f77blas.h",
|
||||
"lapack.h",
|
||||
"lapacke.h",
|
||||
"lapacke_config.h",
|
||||
"lapacke_mangling.h",
|
||||
"lapacke_utils.h",
|
||||
"openblas_config.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libblas.*",
|
||||
"libcblas.*",
|
||||
"libclapack.*",
|
||||
"libf77lapack.*",
|
||||
"liblapack.*",
|
||||
"liblapacke.*",
|
||||
"libopenblas.*",
|
||||
"libopenblas.*",
|
||||
"libopenblasp*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "openldap",
|
||||
"libraries": [
|
||||
"liblber.*",
|
||||
"liblber_r.*",
|
||||
"libldap.*",
|
||||
"libldap_r.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "openpam",
|
||||
"headers": [
|
||||
"security"
|
||||
],
|
||||
"libraries": [
|
||||
"libpam.*",
|
||||
"pam_*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "pcre",
|
||||
"headers": [
|
||||
"pcre.h",
|
||||
"pcreposix.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libpcre.*",
|
||||
"libpcre2*",
|
||||
"libpcreposix.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "php",
|
||||
"headers": [
|
||||
"php"
|
||||
],
|
||||
"libraries": [
|
||||
"php"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "postgresql",
|
||||
"libraries": [
|
||||
"libecpg*",
|
||||
"libpg*",
|
||||
"libpq*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "python",
|
||||
"headers": [
|
||||
"python*"
|
||||
],
|
||||
"frameworks": [
|
||||
"Python.framework"
|
||||
],
|
||||
"libraries": [
|
||||
"libpython*",
|
||||
"python*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "readline",
|
||||
"headers": [
|
||||
"readline"
|
||||
],
|
||||
"libraries": [
|
||||
"libhistory.*",
|
||||
"libreadline.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "ruby",
|
||||
"frameworks": [
|
||||
"Ruby.framework"
|
||||
],
|
||||
"libraries": [
|
||||
"libruby.*",
|
||||
"ruby"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "sqlite3",
|
||||
"headers": [
|
||||
"sqlite3.h",
|
||||
"sqlite3ext.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libsqlite3.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "swift",
|
||||
"libraries": [
|
||||
"swift/shims"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "tcl",
|
||||
"headers": [
|
||||
"tcl*",
|
||||
"tk*"
|
||||
],
|
||||
"frameworks": [
|
||||
"Tcl.framework",
|
||||
"Tk.framework"
|
||||
],
|
||||
"libraries": [
|
||||
"libtcl*",
|
||||
"libtk*",
|
||||
"tclConfig.sh",
|
||||
"tkConfig.sh"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "xar",
|
||||
"headers": [
|
||||
"xar"
|
||||
],
|
||||
"libraries": [
|
||||
"libxar.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "xz",
|
||||
"headers": [
|
||||
"lzma*"
|
||||
],
|
||||
"libraries": [
|
||||
"liblzma.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"package": "zlib",
|
||||
"headers": [
|
||||
"zconf.h",
|
||||
"zlib.h"
|
||||
],
|
||||
"libraries": [
|
||||
"libz.*"
|
||||
]
|
||||
}
|
||||
]
|
||||
26
nix/apple-sdk/metadata/versions.json
Normal file
26
nix/apple-sdk/metadata/versions.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"14": {
|
||||
"urls": [
|
||||
"https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg",
|
||||
"https://web.archive.org/web/20250211001355/https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg"
|
||||
],
|
||||
"version": "14.4",
|
||||
"hash": "sha256-QozDiwY0Czc0g45vPD7G4v4Ra+3DujCJbSads3fJjjM="
|
||||
},
|
||||
"15": {
|
||||
"urls": [
|
||||
"https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg",
|
||||
"https://web.archive.org/web/20250530132510/https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg"
|
||||
],
|
||||
"version": "15.5",
|
||||
"hash": "sha256-HBiSJuw1XBUK5R/8Sj65c3rftSEvQl/O9ZZVp/g1Amo="
|
||||
},
|
||||
"26": {
|
||||
"urls": [
|
||||
"https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg",
|
||||
"https://web.archive.org/web/20250915230423/https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg"
|
||||
],
|
||||
"version": "26.2",
|
||||
"hash": "sha256-hXRlMieVv0smna5uiWRwq87IWOaPWtAjAldbi+wQXcw="
|
||||
}
|
||||
}
|
||||
110
nix/apple-sdk/package.nix
Normal file
110
nix/apple-sdk/package.nix
Normal file
@@ -0,0 +1,110 @@
|
||||
let
|
||||
sdkVersions = builtins.fromJSON (builtins.readFile ./metadata/versions.json);
|
||||
in
|
||||
|
||||
{ lib
|
||||
, stdenv
|
||||
, stdenvNoCC
|
||||
, substitute
|
||||
, # Specifies the major version used for the SDK. Uses `hostPlatform.darwinSdkVersion` by default.
|
||||
darwinSdkMajorVersion ? lib.versions.major stdenv.hostPlatform.darwinSdkVersion
|
||||
, # Enabling bootstrap disables propagation. Defaults to `false` (meaning to propagate certain packages and `xcrun`)
|
||||
# except in stage0 of the Darwin stdenv bootstrap.
|
||||
enableBootstrap ? stdenv.name == "bootstrap-stage0-stdenv-darwin"
|
||||
, # Required by various phases
|
||||
callPackage
|
||||
,
|
||||
}:
|
||||
|
||||
let
|
||||
sdkInfo =
|
||||
sdkVersions.${darwinSdkMajorVersion}
|
||||
or (lib.throw "Unsupported SDK major version: ${darwinSdkMajorVersion}");
|
||||
sdkVersion = sdkInfo.version;
|
||||
|
||||
fetchSDK = callPackage ./common/fetch-sdk.nix { };
|
||||
|
||||
phases = lib.composeManyExtensions (
|
||||
[
|
||||
(callPackage ./common/add-core-symbolication.nix { })
|
||||
(callPackage ./common/derivation-options.nix { })
|
||||
(callPackage ./common/passthru-private-frameworks.nix { inherit sdkVersion; })
|
||||
(callPackage ./common/passthru-source-release-files.nix { inherit sdkVersion; })
|
||||
(callPackage ./common/remove-disallowed-packages.nix { })
|
||||
(callPackage ./common/process-stubs.nix { })
|
||||
]
|
||||
# Avoid infinite recursions by not propagating certain packages, so they can themselves build with the SDK.
|
||||
++ lib.optionals (!enableBootstrap) [
|
||||
(callPackage ./common/propagate-inputs.nix { })
|
||||
(callPackage ./common/propagate-xcrun.nix { inherit sdkVersion; })
|
||||
]
|
||||
# This has to happen last.
|
||||
++ [
|
||||
(callPackage ./common/run-build-phase-hooks.nix { })
|
||||
]
|
||||
);
|
||||
in
|
||||
stdenvNoCC.mkDerivation (
|
||||
lib.extends phases (finalAttrs: {
|
||||
pname = "apple-sdk";
|
||||
inherit (sdkInfo) version;
|
||||
|
||||
src = fetchSDK sdkInfo;
|
||||
|
||||
dontConfigure = true;
|
||||
|
||||
strictDeps = true;
|
||||
|
||||
setupHooks = [
|
||||
# `role.bash` is copied from `../build-support/setup-hooks/role.bash` due to the requirements not to reference
|
||||
# paths outside the package when it is in `by-name`. It needs to be kept in sync, but it fortunately does not
|
||||
# change often. Once `build-support` is available as a package (or some other mechanism), it should be changed
|
||||
# to whatever that replacement is.
|
||||
./setup-hooks/role.bash
|
||||
(substitute {
|
||||
src = ./setup-hooks/sdk-hook.sh;
|
||||
substitutions = [
|
||||
"--subst-var-by"
|
||||
"sdkVersion"
|
||||
(lib.escapeShellArgs (lib.splitVersion sdkVersion))
|
||||
];
|
||||
})
|
||||
];
|
||||
|
||||
installPhase =
|
||||
let
|
||||
sdkName = "MacOSX${lib.versions.majorMinor sdkVersion}.sdk";
|
||||
sdkMajor = lib.versions.major sdkVersion;
|
||||
in
|
||||
''
|
||||
runHook preInstall
|
||||
|
||||
mkdir -p "$sdkpath"
|
||||
|
||||
cp -rd . "$sdkpath/${sdkName}"
|
||||
ln -s "${sdkName}" "$sdkpath/MacOSX${sdkMajor}.sdk"
|
||||
ln -s "${sdkName}" "$sdkpath/MacOSX.sdk"
|
||||
|
||||
# Swift adds these locations to its search paths. Avoid spurious warnings by making sure they exist.
|
||||
mkdir -p "$platformPath/Developer/Library/Frameworks"
|
||||
mkdir -p "$platformPath/Developer/Library/PrivateFrameworks"
|
||||
mkdir -p "$platformPath/Developer/usr/lib"
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
passthru = {
|
||||
sdkroot = finalAttrs.finalPackage + "/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk";
|
||||
};
|
||||
|
||||
__structuredAttrs = true;
|
||||
|
||||
meta = {
|
||||
description = "Frameworks and libraries required for building packages on Darwin";
|
||||
homepage = "https://developer.apple.com";
|
||||
teams = [ lib.teams.darwin ];
|
||||
platforms = lib.platforms.darwin;
|
||||
badPlatforms = [ lib.systems.inspect.patterns.is32bit ];
|
||||
};
|
||||
})
|
||||
)
|
||||
@@ -0,0 +1,48 @@
|
||||
From 6531da946949a94643e6d8424236174ae64fe0ca Mon Sep 17 00:00:00 2001
|
||||
From: Randy Eckenrode <randy@largeandhighquality.com>
|
||||
Date: Sat, 30 Sep 2023 18:02:39 -0400
|
||||
Subject: [PATCH 1/2] Add function definitions needed to build zlog in
|
||||
system_cmds
|
||||
|
||||
---
|
||||
CoreSymbolication.h | 10 +++++++---
|
||||
1 file changed, 7 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
|
||||
index a413860..f3cf63f 100644
|
||||
--- a/CoreSymbolication.h
|
||||
+++ b/CoreSymbolication.h
|
||||
@@ -324,7 +324,9 @@ CSSymbolOwnerEditRelocations
|
||||
CSSymbolOwnerForeachRegion
|
||||
CSSymbolOwnerForeachRegionWithName
|
||||
CSSymbolOwnerForeachSection
|
||||
-CSSymbolOwnerForeachSegment
|
||||
+*/
|
||||
+void CSSymbolOwnerForeachSegment(CSSymbolOwnerRef owner, void (^block)(CSSegmentRef));
|
||||
+/*
|
||||
CSSymbolOwnerForeachSourceInfo
|
||||
CSSymbolOwnerForeachSymbol
|
||||
*/
|
||||
@@ -333,7 +335,9 @@ void CSSymbolOwnerForeachSymbolWithName(CSSymbolOwnerRef owner, const char *sna
|
||||
/*
|
||||
CSSymbolOwnerGetArchitecture
|
||||
CSSymbolOwnerGetBaseAddress
|
||||
-CSSymbolOwnerGetCFUUIDBytes
|
||||
+*/
|
||||
+const CFUUIDBytes* CSSymbolOwnerGetCFUUIDBytes(CSSymbolOwnerRef owner);
|
||||
+/*
|
||||
CSSymbolOwnerGetCompatibilityVersion
|
||||
CSSymbolOwnerGetCurrentVersion
|
||||
CSSymbolOwnerGetDataFlags
|
||||
@@ -390,7 +394,7 @@ CSSymbolOwnerSetLoadTimestamp
|
||||
CSSymbolOwnerSetPath
|
||||
CSSymbolOwnerSetRelocationCount
|
||||
*/
|
||||
-CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
|
||||
+void CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
|
||||
/*
|
||||
CSSymbolOwnerSetUnloadTimestamp
|
||||
*/
|
||||
--
|
||||
2.44.1
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
From ae7ac6a7043dbae8e63d6ce5e63dfaf02b5977fe Mon Sep 17 00:00:00 2001
|
||||
From: Randy Eckenrode <randy@largeandhighquality.com>
|
||||
Date: Sat, 30 Sep 2023 18:37:18 -0400
|
||||
Subject: [PATCH 2/2] Add CF_EXPORT To const symbols
|
||||
|
||||
---
|
||||
CoreSymbolication.h | 15 ++++++++-------
|
||||
1 file changed, 8 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
|
||||
index f3cf63f..4124a54 100644
|
||||
--- a/CoreSymbolication.h
|
||||
+++ b/CoreSymbolication.h
|
||||
@@ -49,6 +49,7 @@
|
||||
|
||||
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
+#include <CoreFoundation/CFBase.h>
|
||||
#include <mach/mach.h>
|
||||
|
||||
|
||||
@@ -139,13 +140,13 @@ typedef void (^CSSegmentIterator)(CSSegmentRef segment);
|
||||
* External symbols
|
||||
*/
|
||||
|
||||
-const char* kCSRegionMachHeaderName;
|
||||
-const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
|
||||
-const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
|
||||
-const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
|
||||
-const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
|
||||
-const CSSetCallBacks kCSTypeSetCallBacks;
|
||||
-const CSSetCallBacks kCSTypeSetWeakCallBacks;
|
||||
+CF_EXPORT const char* kCSRegionMachHeaderName;
|
||||
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
|
||||
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
|
||||
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
|
||||
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
|
||||
+CF_EXPORT const CSSetCallBacks kCSTypeSetCallBacks;
|
||||
+CF_EXPORT const CSSetCallBacks kCSTypeSetWeakCallBacks;
|
||||
|
||||
|
||||
/*
|
||||
--
|
||||
2.44.1
|
||||
|
||||
41
nix/apple-sdk/scripts/get-sdks-from-catalog.sh
Normal file
41
nix/apple-sdk/scripts/get-sdks-from-catalog.sh
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env nix-shell
|
||||
#!nix-shell -i bash -p coreutils curl file gzip jq xcbuild yq
|
||||
|
||||
set -eu -o pipefail
|
||||
|
||||
catalog=${1-}
|
||||
|
||||
if [ -z "$catalog" ]; then
|
||||
echo "usage: get-sdks-from-catalog.sh <catalog>"
|
||||
echo " <catalog> Apple software update catalog (may be gzipped)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
scratch=$(mktemp)
|
||||
trap 'rm -f -- "$scratch"' EXIT
|
||||
|
||||
if [[ "$(file "$catalog")" =~ gzip ]]; then
|
||||
gzcat "$catalog" >"$scratch"
|
||||
else
|
||||
cp --reflink=auto "$catalog" "$scratch"
|
||||
fi
|
||||
|
||||
# Grab all SDK packages from the catalog
|
||||
filter='.Products[].Packages[] | select(.URL | test(".*CLTools_macOSNMOS_SDK.pkg")) | "\(.URL)|\(.MetadataURL)"'
|
||||
|
||||
declare -A package_list
|
||||
for package in $(plutil -convert json -o - "$scratch" | jq -r "$filter"); do
|
||||
package_list[${package%%|*}]=${package#*|}
|
||||
done
|
||||
|
||||
truncate --size 0 "$scratch"
|
||||
for pkg in "${!package_list[@]}"; do
|
||||
ver=$(curl --silent "${package_list[$pkg]}" | xq -r '."pkg-info"."@version"')
|
||||
echo "{\"url\": \"$pkg\", \"version\": \"$(cut -d. -f1-3 <<<"$ver")\", \"long_version\": \"$ver\"}" >>"$scratch"
|
||||
done
|
||||
|
||||
jq -r --slurp '
|
||||
group_by(.version | split(".")[0])
|
||||
| map(max_by(.version))
|
||||
| sort_by(.version)[]
|
||||
| "Package URL: \(.url)\n Xcode Ver: \(.version) (\(.long_version))\n"' "$scratch"
|
||||
70
nix/apple-sdk/scripts/lock-sdk-deps.sh
Normal file
70
nix/apple-sdk/scripts/lock-sdk-deps.sh
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env nix-shell
|
||||
#!nix-shell -i bash -p coreutils curl git gnutar jq moreutils nix
|
||||
|
||||
set -eu -o pipefail
|
||||
|
||||
if [ ! -v 2 ]; then
|
||||
echo "usage: lock-sdk-deps.sh <SDK version> <Packages>" >&2
|
||||
echo " <SDK version> Decimal-separated version number." >&2
|
||||
echo " Must correspond to a tag in https://github.com/apple-oss-distributions/distribution-macOS" >&2
|
||||
echo " <Packages> List of packages from the distributions-macOS repository." >&2
|
||||
echo " Packages not in the repository at the tag for <SDK version> will be ignored."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
|
||||
|
||||
lockfile=$pkgdir/metadata/apple-oss-lockfile.json
|
||||
if [ ! -e "$lockfile" ]; then
|
||||
touch "$lockfile"
|
||||
fi
|
||||
|
||||
workdir=$(mktemp -d)
|
||||
trap 'rm -rf -- "$workdir"' EXIT
|
||||
|
||||
sdkVersion=$1
|
||||
shift
|
||||
tag="macos-${sdkVersion//./}"
|
||||
|
||||
declare -a packages=("$@")
|
||||
|
||||
echo "Locking versions for macOS $sdkVersion using tag '$tag'..."
|
||||
|
||||
pushd "$workdir" >/dev/null
|
||||
|
||||
git clone --branch "$tag" https://github.com/apple-oss-distributions/distribution-macOS.git &>/dev/null
|
||||
cd distribution-macOS
|
||||
|
||||
for package in "${packages[@]}"; do
|
||||
# If the tag exists in `release.json`, use that as an optimization to avoid downloading unnecessarily from Github.
|
||||
packageTag=$(jq -r --arg package "$package" '.projects[] | select(.project == $package) | .tag' release.json)
|
||||
packageCommit=$(git ls-tree -d HEAD "$package" | awk '{print $3}')
|
||||
|
||||
if [ ! -d "$package" ]; then
|
||||
packageCommit=HEAD
|
||||
fi
|
||||
|
||||
# However, sometimes it doesn’t exist. In that case, fall back to cloning the repo and check manually
|
||||
# which tag corresponds to the commit from the submodule.
|
||||
if [ -z "$packageTag" ]; then
|
||||
git clone --no-checkout "https://github.com/apple-oss-distributions/$package.git" ../source &>/dev/null
|
||||
pushd ../source >/dev/null
|
||||
packageTag=$(git tag --points-at "$packageCommit")
|
||||
popd >/dev/null
|
||||
rm -rf ../source
|
||||
fi
|
||||
|
||||
packageVersion=${packageTag##"$package"-}
|
||||
|
||||
curl -OL "https://github.com/apple-oss-distributions/$package/archive/$packageTag.tar.gz" &>/dev/null
|
||||
tar axf "$packageTag.tar.gz"
|
||||
|
||||
packageHash=$(nix --extra-experimental-features nix-command hash path "$package-$packageTag")
|
||||
|
||||
pkgsjson="{\"$sdkVersion\": {\"$package\": {\"version\": \"$packageVersion\", \"hash\": \"$packageHash\"}}}"
|
||||
|
||||
echo " - Locking $package to version $packageVersion with hash '$packageHash'"
|
||||
jq --argjson pkg "$pkgsjson" -S '. * $pkg' "$lockfile" | sponge "$lockfile"
|
||||
done
|
||||
|
||||
popd >/dev/null
|
||||
62
nix/apple-sdk/scripts/regenerate-lockfile.sh
Normal file
62
nix/apple-sdk/scripts/regenerate-lockfile.sh
Normal file
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env nix-shell
|
||||
#!nix-shell -i bash -p coreutils jq
|
||||
|
||||
set -eu -o pipefail
|
||||
|
||||
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
|
||||
|
||||
echo '{}' >"$pkgdir/metadata/apple-oss-lockfile.json"
|
||||
|
||||
declare -a versions
|
||||
readarray -t versions < <(jq -r '.[].version' "$pkgdir/metadata/versions.json")
|
||||
|
||||
declare -a packages=(
|
||||
CarbonHeaders
|
||||
CommonCrypto
|
||||
IOAudioFamily
|
||||
IOFireWireFamily
|
||||
IOFWDVComponents
|
||||
IOFireWireAVC
|
||||
IOFireWireSBP2
|
||||
IOFireWireSerialBusProtocolTransport
|
||||
IOGraphics
|
||||
IOHIDFamily
|
||||
IONetworkingFamily
|
||||
IOSerialFamily
|
||||
IOStorageFamily
|
||||
IOBDStorageFamily
|
||||
IOCDStorageFamily
|
||||
IODVDStorageFamily
|
||||
IOUSBFamily
|
||||
IOKitUser
|
||||
Libc
|
||||
Libinfo
|
||||
Libm
|
||||
Libnotify
|
||||
Librpcsvc
|
||||
Libsystem
|
||||
OpenDirectory
|
||||
Security
|
||||
architecture
|
||||
configd
|
||||
copyfile
|
||||
dtrace
|
||||
dyld
|
||||
eap8021x
|
||||
hfs
|
||||
launchd
|
||||
libclosure
|
||||
libdispatch
|
||||
libmalloc
|
||||
libplatform
|
||||
libpthread
|
||||
mDNSResponder
|
||||
objc4
|
||||
ppp
|
||||
removefile
|
||||
xnu
|
||||
)
|
||||
|
||||
for version in "${versions[@]}"; do
|
||||
"$pkgdir/scripts/lock-sdk-deps.sh" "$version" "${packages[@]}"
|
||||
done
|
||||
6
nix/apple-sdk/setup-hooks/add-private-frameworks.sh
Normal file
6
nix/apple-sdk/setup-hooks/add-private-frameworks.sh
Normal file
@@ -0,0 +1,6 @@
|
||||
function enablePrivateFrameworks() {
|
||||
export NIX_CFLAGS_COMPILE+=" -iframework $DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
|
||||
export NIX_LDFLAGS+=" -F$DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
|
||||
}
|
||||
|
||||
preConfigureHooks+=(enablePrivateFrameworks)
|
||||
71
nix/apple-sdk/setup-hooks/role.bash
Normal file
71
nix/apple-sdk/setup-hooks/role.bash
Normal file
@@ -0,0 +1,71 @@
|
||||
# Since the same derivation can be depended on in multiple ways, we need to
|
||||
# accumulate *each* role (i.e. host and target platforms relative the depending
|
||||
# derivation) in which the derivation is used.
|
||||
#
|
||||
# The role is intended to be used as part of other variables names like
|
||||
# - $NIX_SOMETHING${role_post}
|
||||
|
||||
function getRole() {
|
||||
case $1 in
|
||||
-1)
|
||||
role_post='_FOR_BUILD'
|
||||
;;
|
||||
0)
|
||||
role_post=''
|
||||
;;
|
||||
1)
|
||||
role_post='_FOR_TARGET'
|
||||
;;
|
||||
*)
|
||||
echo "@name@: used as improper sort of dependency" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# `hostOffset` describes how the host platform of the package is slid relative
|
||||
# to the depending package. `targetOffset` likewise describes the target
|
||||
# platform of the package. Both are brought into scope of the setup hook defined
|
||||
# for dependency whose setup hook is being processed relative to the package
|
||||
# being built.
|
||||
|
||||
function getHostRole() {
|
||||
getRole "$hostOffset"
|
||||
}
|
||||
function getTargetRole() {
|
||||
getRole "$targetOffset"
|
||||
}
|
||||
|
||||
# `depHostOffset` describes how the host platform of the dependencies are slid
|
||||
# relative to the depending package. `depTargetOffset` likewise describes the
|
||||
# target platform of dependenices. Both are brought into scope of the
|
||||
# environment hook defined for the dependency being applied relative to the
|
||||
# package being built.
|
||||
|
||||
function getHostRoleEnvHook() {
|
||||
getRole "$depHostOffset"
|
||||
}
|
||||
function getTargetRoleEnvHook() {
|
||||
getRole "$depTargetOffset"
|
||||
}
|
||||
|
||||
# This variant is intended specifically for code-producing tool wrapper scripts
|
||||
# `NIX_@wrapperName@_TARGET_*_@suffixSalt@` tracks this (needs to be an exported
|
||||
# env var so can't use fancier data structures).
|
||||
function getTargetRoleWrapper() {
|
||||
case $targetOffset in
|
||||
-1)
|
||||
export NIX_@wrapperName@_TARGET_BUILD_@suffixSalt@=1
|
||||
;;
|
||||
0)
|
||||
export NIX_@wrapperName@_TARGET_HOST_@suffixSalt@=1
|
||||
;;
|
||||
1)
|
||||
export NIX_@wrapperName@_TARGET_TARGET_@suffixSalt@=1
|
||||
;;
|
||||
*)
|
||||
echo "@name@: used as improper sort of dependency" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
17
nix/apple-sdk/setup-hooks/sdk-hook.sh
Normal file
17
nix/apple-sdk/setup-hooks/sdk-hook.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
local role_post
|
||||
getHostRole
|
||||
|
||||
local sdkVersionVar=NIX_APPLE_SDK_VERSION${role_post}
|
||||
local developerDirVar=DEVELOPER_DIR${role_post}
|
||||
|
||||
local sdkVersionArr=(@sdkVersion@)
|
||||
local sdkVersion
|
||||
sdkVersion=$(printf "%02d%02d%02d" "${sdkVersionArr[0]-0}" "${sdkVersionArr[1]-0}" "${sdkVersionArr[2]-0}")
|
||||
|
||||
if [ "$sdkVersion" -gt "${!sdkVersionVar-000000}" ]; then
|
||||
export "$developerDirVar"='@out@'
|
||||
export "$sdkVersionVar"="$sdkVersion"
|
||||
export "SDKROOT${role_post}"="${!developerDirVar}/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk"
|
||||
fi
|
||||
|
||||
unset -v role_post developerDirVar sdkVersion sdkVersionArr sdkVersionVar
|
||||
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
version = let v = "0.30.5"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
@@ -86,6 +86,7 @@ let
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeBool "MLX_BUILD_CPU" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"mlx==0.30.5; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.5; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -31,8 +31,6 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
|
||||
# dependencies only required for development
|
||||
@@ -63,7 +61,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
@@ -105,6 +103,7 @@ root = "src"
|
||||
|
||||
# supported platforms for this project
|
||||
[tool.uv]
|
||||
required-version = ">=0.8.6"
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
|
||||
@@ -59,6 +59,22 @@
|
||||
}
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ exoVenv ];
|
||||
runtimeEnv = {
|
||||
EXO_DASHBOARD_DIR = self'.packages.dashboard;
|
||||
EXO_RESOURCES_DIR = inputs.self + /resources;
|
||||
};
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ pkgs.python313 ];
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
@@ -66,28 +82,30 @@
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
# Create wrapper script
|
||||
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
|
||||
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
|
||||
{
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
} // {
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 45644286500
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57657697020
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 68899327465
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 89357758772
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 157548627945
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -53,11 +54,10 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
@@ -108,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -158,6 +158,78 @@ async def seed_models(seed_dir: str | Path):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def _build_file_list_from_local_directory(
|
||||
model_id: ModelId,
|
||||
recursive: bool = False,
|
||||
) -> list[FileListEntry] | None:
|
||||
"""Build a file list from locally existing model files.
|
||||
|
||||
We can only figure out the files we need from safetensors index, so
|
||||
a local directory must contain a *.safetensors.index.json and
|
||||
safetensors listed there.
|
||||
"""
|
||||
model_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return None
|
||||
|
||||
def _scan() -> list[FileListEntry] | None:
|
||||
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
|
||||
if not index_files:
|
||||
return None
|
||||
|
||||
entries_by_path: dict[str, FileListEntry] = {}
|
||||
|
||||
if recursive:
|
||||
for dirpath, _, filenames in os.walk(model_dir):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".partial"):
|
||||
continue
|
||||
full_path = Path(dirpath) / filename
|
||||
rel_path = str(full_path.relative_to(model_dir))
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=full_path.stat().st_size,
|
||||
)
|
||||
else:
|
||||
for item in model_dir.iterdir():
|
||||
if item.is_file() and not item.name.endswith(".partial"):
|
||||
entries_by_path[item.name] = FileListEntry(
|
||||
type="file",
|
||||
path=item.name,
|
||||
size=item.stat().st_size,
|
||||
)
|
||||
|
||||
# Add expected weight files from index that haven't been downloaded yet
|
||||
for index_file in index_files:
|
||||
try:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(
|
||||
index_file.read_text()
|
||||
)
|
||||
relative_dir = index_file.parent.relative_to(model_dir)
|
||||
for filename in set(index_data.weight_map.values()):
|
||||
rel_path = (
|
||||
str(relative_dir / filename)
|
||||
if relative_dir != Path(".")
|
||||
else filename
|
||||
)
|
||||
if rel_path not in entries_by_path:
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=None,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return list(entries_by_path.values())
|
||||
|
||||
file_list = await asyncio.to_thread(_scan)
|
||||
if not file_list:
|
||||
return None
|
||||
return file_list
|
||||
|
||||
|
||||
_fetched_file_lists_this_session: set[str] = set()
|
||||
|
||||
|
||||
@@ -183,6 +255,14 @@ async def fetch_file_list_with_cache(
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(
|
||||
f"No internet connection and no cached file list for {model_id}"
|
||||
)
|
||||
@@ -203,10 +283,18 @@ async def fetch_file_list_with_cache(
|
||||
except Exception as e:
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id} and no cache exists, "
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
|
||||
|
||||
|
||||
@@ -378,10 +466,14 @@ async def download_file_with_retry(
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
on_connection_lost()
|
||||
raise e
|
||||
break
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
@@ -195,6 +195,10 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self, shard: ShardMetadata
|
||||
) -> RepoDownloadProgress:
|
||||
_, progress = await download_shard(
|
||||
shard, self.on_progress_wrapper, skip_download=True
|
||||
shard,
|
||||
self.on_progress_wrapper,
|
||||
skip_download=True,
|
||||
skip_internet=not self.internet_connection,
|
||||
on_connection_lost=lambda: self.set_internet_connection(False),
|
||||
)
|
||||
return progress
|
||||
|
||||
@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -106,6 +105,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -136,7 +136,6 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -148,6 +147,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -188,6 +189,9 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -1320,29 +1320,40 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
shutdown_trigger=ev.wait,
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
@@ -66,12 +68,9 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -81,6 +80,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -96,16 +96,18 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -278,6 +280,14 @@ class Master:
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
|
||||
@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -202,3 +208,29 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -9,6 +9,7 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -146,18 +147,21 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -166,12 +170,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,28 +86,29 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
Command = (
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
|
||||
@@ -194,9 +194,10 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import MiniMaxAttention
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
@@ -503,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
|
||||
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
# layer.self_attn.kv_b_proj
|
||||
# )
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
@@ -624,6 +644,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: "Cache | None" = None,
|
||||
) -> mx.array:
|
||||
batch_dim, seq_dim, _ = x.shape
|
||||
|
||||
self._original_layer = cast(MiniMaxAttention, self.original_layer) # type: ignore
|
||||
|
||||
queries: mx.array = self._original_layer.q_proj(x)
|
||||
keys: mx.array = self._original_layer.k_proj(x)
|
||||
values: mx.array = self._original_layer.v_proj(x)
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1]
|
||||
k_dim = keys.shape[-1]
|
||||
n = self.group.size()
|
||||
|
||||
qk = mx.concatenate(
|
||||
[queries, keys], axis=-1
|
||||
) # (batch_dim, seq_dim, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (n*batch_dim, seq_dim, q_dim + k_dim)
|
||||
|
||||
qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * q_dim)
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * k_dim)
|
||||
|
||||
queries = self._original_layer.q_norm(queries)
|
||||
keys = self._original_layer.k_norm(keys)
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, n, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self._original_layer.rope(queries, offset=cache.offset)
|
||||
keys = self._original_layer.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self._original_layer.rope(queries)
|
||||
keys = self._original_layer.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self._original_layer.scale, # type: ignore
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
|
||||
|
||||
return self._original_layer.o_proj(output)
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -632,7 +730,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -643,18 +740,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -679,18 +769,95 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
if isinstance(layer, Qwen3DecoderLayer):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
linear_attn = layer.linear_attn
|
||||
|
||||
linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_qkvz
|
||||
)
|
||||
linear_attn.in_proj_ba = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_ba
|
||||
)
|
||||
linear_attn.out_proj = self.sharded_to_all_linear(
|
||||
linear_attn.out_proj
|
||||
)
|
||||
|
||||
# Shard conv1d: depthwise conv with non-contiguous channel slicing.
|
||||
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
|
||||
# Each rank takes its head-slice from each of the three sections.
|
||||
rank = self.group.rank()
|
||||
key_dim = linear_attn.key_dim
|
||||
value_dim = linear_attn.value_dim
|
||||
key_dim_shard = key_dim // self.N
|
||||
value_dim_shard = value_dim // self.N
|
||||
|
||||
q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
|
||||
k_idx = mx.arange(
|
||||
key_dim + rank * key_dim_shard,
|
||||
key_dim + (rank + 1) * key_dim_shard,
|
||||
)
|
||||
v_idx = mx.arange(
|
||||
2 * key_dim + rank * value_dim_shard,
|
||||
2 * key_dim + (rank + 1) * value_dim_shard,
|
||||
)
|
||||
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
|
||||
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
|
||||
new_conv_dim = key_dim_shard * 2 + value_dim_shard
|
||||
linear_attn.conv1d.groups = new_conv_dim
|
||||
|
||||
num_v_shard = linear_attn.num_v_heads // self.N
|
||||
v_start = rank * num_v_shard
|
||||
v_end = v_start + num_v_shard
|
||||
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
|
||||
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]
|
||||
|
||||
linear_attn.num_k_heads //= self.N
|
||||
linear_attn.num_v_heads //= self.N
|
||||
linear_attn.key_dim = (
|
||||
linear_attn.head_k_dim * linear_attn.num_k_heads
|
||||
)
|
||||
linear_attn.value_dim = (
|
||||
linear_attn.head_v_dim * linear_attn.num_v_heads
|
||||
)
|
||||
linear_attn.conv_dim = (
|
||||
linear_attn.key_dim * 2 + linear_attn.value_dim
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
@@ -700,6 +867,14 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_expert.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_expert.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -26,51 +24,119 @@ _MEMORY_THRESHOLD = float(
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
class CacheSnapshot:
|
||||
"""Snapshot of states at a known token position."""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
|
||||
):
|
||||
self.states = states
|
||||
self.token_count = token_count
|
||||
|
||||
|
||||
def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
|
||||
states: list[ArraysCache | RotatingKVCache | None] = []
|
||||
for c in cache:
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
states.append(deepcopy(c))
|
||||
else:
|
||||
states.append(None)
|
||||
token_count = cache_length(cache)
|
||||
return CacheSnapshot(states=states, token_count=token_count)
|
||||
|
||||
|
||||
def _find_nearest_snapshot(
|
||||
snapshots: list[CacheSnapshot],
|
||||
target_token_count: int,
|
||||
) -> CacheSnapshot | None:
|
||||
best: CacheSnapshot | None = None
|
||||
for snap in snapshots:
|
||||
if snap.token_count <= target_token_count and (
|
||||
best is None or snap.token_count > best.token_count
|
||||
):
|
||||
best = snap
|
||||
return best
|
||||
|
||||
|
||||
def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
"""Check if a cache contains any ArraysCache (SSM) entries."""
|
||||
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
ssm_snapshots: list[CacheSnapshot] | None = None,
|
||||
):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.prompts.append(prompt_tokens)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._snapshots.append(ssm_snapshots)
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt: str,
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
snapshots: list[CacheSnapshot] | None,
|
||||
restore_pos: int,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
old_snapshots = self._snapshots[index]
|
||||
merged: list[CacheSnapshot] = []
|
||||
if old_snapshots:
|
||||
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
|
||||
if snapshots:
|
||||
merged.extend(snapshots)
|
||||
|
||||
self.prompts[index] = prompt_tokens
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._snapshots[index] = merged or None
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
|
||||
|
||||
def _get_snapshot(
|
||||
self, entry_index: int, target_token_count: int
|
||||
) -> tuple[int, CacheSnapshot | None]:
|
||||
if not has_non_kv_caches(self.caches[entry_index]):
|
||||
return target_token_count, None
|
||||
|
||||
snapshots = self._snapshots[entry_index]
|
||||
if not snapshots:
|
||||
return 0, None
|
||||
|
||||
snap = _find_nearest_snapshot(snapshots, target_token_count)
|
||||
if snap is not None:
|
||||
return snap.token_count, snap
|
||||
|
||||
return 0, None
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
prompt: str,
|
||||
prompt_tokens: mx.array,
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
@@ -79,76 +145,71 @@ class KVPrefixCache:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
|
||||
For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
|
||||
nearest SSM snapshot position at or before the match point for correctness.
|
||||
Same for rotating KV Cache.
|
||||
"""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
max_length = len(prompt_tokens)
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
best_index: int | None = None
|
||||
best_length = 0
|
||||
is_exact = False
|
||||
|
||||
# Find best cache
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
length = get_prefix_length(prompt_tokens, cached_prompt)
|
||||
if length > best_length:
|
||||
best_index, best_length = i, length
|
||||
if length == max_length:
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
is_exact = True
|
||||
best_index, best_length = i, length
|
||||
break
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
if best_index is None:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
# For exact match: trim to max_length-1 so remaining has the last token
|
||||
# For partial match: trim to best_length, remaining has suffix to prefill
|
||||
# This ensures stream_generate always has at least one token to start with
|
||||
target = (max_length - 1) if is_exact else best_length
|
||||
restore_pos, restore_snap = self._get_snapshot(best_index, target)
|
||||
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
# No usable snapshot — need fresh cache
|
||||
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
prompt_cache = deepcopy(self.caches[best_index])
|
||||
cached_length = cache_length(self.caches[best_index])
|
||||
tokens_to_trim = cached_length - restore_pos
|
||||
if tokens_to_trim > 0:
|
||||
trim_cache(prompt_cache, tokens_to_trim, restore_snap)
|
||||
# Reset cache offset to match trimmed length
|
||||
for c in prompt_cache:
|
||||
if hasattr(c, "offset"):
|
||||
c.offset = restore_pos
|
||||
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
self._access_counter += 1
|
||||
self._last_used[best_index] = self._access_counter
|
||||
remaining = prompt_tokens[restore_pos:]
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
return prompt_cache, remaining, best_index
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
# Evict LRU entries until below threshold
|
||||
while (
|
||||
len(self.caches) > 1
|
||||
len(self.caches) > 0
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._snapshots.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
@@ -169,6 +230,21 @@ class KVPrefixCache:
|
||||
return max_pressure
|
||||
|
||||
|
||||
def trim_cache(
|
||||
cache: KVCacheType,
|
||||
num_tokens: int,
|
||||
snapshot: CacheSnapshot | None = None,
|
||||
) -> None:
|
||||
for i, c in enumerate(cache):
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
if snapshot is not None and snapshot.states[i] is not None:
|
||||
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
|
||||
else:
|
||||
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
else:
|
||||
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
@@ -177,14 +253,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(prompt_tokens)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
return max(getattr(c, "offset", 0) for c in cache)
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
@@ -215,7 +291,7 @@ def make_kv_cache(
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
if hasattr(model, "make_cache"):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
|
||||
@@ -15,8 +15,3 @@ DEFAULT_TOP_LOGPROBS: int = 5
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
|
||||
# MTP enables speculative decoding using the model's built-in draft layer
|
||||
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
|
||||
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import time
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
@@ -23,17 +24,23 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
encode_prompt,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
DEFAULT_TOP_LOGPROBS,
|
||||
KV_BITS,
|
||||
KV_GROUP_SIZE,
|
||||
MAX_TOKENS,
|
||||
MTP_ENABLED,
|
||||
MTP_NUM_DRAFT_TOKENS,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -49,7 +56,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> tuple[float, int]:
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
@@ -60,17 +67,21 @@ def prefill(
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0, 0
|
||||
return 0.0, 0, []
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
has_ssm = has_non_kv_caches(cache)
|
||||
snapshots: list[CacheSnapshot] = []
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
if has_ssm:
|
||||
snapshots.append(snapshot_ssm_states(cache))
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
@@ -87,7 +98,18 @@ def prefill(
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
# Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
|
||||
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
|
||||
for i, c in enumerate(cache):
|
||||
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
assert pre_gen is not None
|
||||
if pre_gen.states[i] is not None:
|
||||
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
@@ -95,12 +117,14 @@ def prefill(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec, num_tokens
|
||||
# Exclude the last snapshot
|
||||
return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -119,7 +143,7 @@ def warmup_inference(
|
||||
)
|
||||
|
||||
# Use a default sampler for warmup
|
||||
sampler = make_sampler(temp=0.7)
|
||||
sampler = make_sampler(temp=0.0)
|
||||
|
||||
logger.info("Generating warmup tokens")
|
||||
for _r in stream_generate(
|
||||
@@ -138,9 +162,7 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated
|
||||
|
||||
@@ -163,11 +185,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def _has_mtp_module(model: Model) -> bool:
|
||||
"""Check if the model has an attached MTP module."""
|
||||
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -228,11 +245,17 @@ def mlx_generate(
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
|
||||
seed = task.seed or 42
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Encode prompt once at the top and fix unmatched think tags
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
|
||||
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
is_bench = task.bench
|
||||
@@ -244,13 +267,16 @@ def mlx_generate(
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prompt_tokens = all_prompt_tokens
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, all_prompt_tokens
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
if prefix_hit_length > 0:
|
||||
logger.info(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -264,24 +290,6 @@ def mlx_generate(
|
||||
top_k=task.top_k if task.top_k is not None else 0,
|
||||
)
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
|
||||
# Check if we should use MTP speculative decoding
|
||||
use_mtp = MTP_ENABLED and _has_mtp_module(model)
|
||||
|
||||
if use_mtp:
|
||||
logger.info("Using MTP speculative decoding")
|
||||
yield from _mlx_generate_with_mtp(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
)
|
||||
return
|
||||
|
||||
# Normalize stop sequences to a list
|
||||
stop_sequences: list[str] = (
|
||||
([task.stop] if isinstance(task.stop, str) else task.stop)
|
||||
@@ -291,13 +299,19 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
last_token = prompt_tokens[-2:]
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
@@ -323,7 +337,6 @@ def mlx_generate(
|
||||
start=1,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
@@ -391,16 +404,6 @@ def mlx_generate(
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
@@ -414,79 +417,44 @@ def mlx_generate(
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
generated_tokens_array = mx.array(
|
||||
tokenizer.encode(
|
||||
"".join(generated_text_parts), add_special_tokens=False
|
||||
)
|
||||
)
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
full_prompt_tokens,
|
||||
caches,
|
||||
cache_snapshots,
|
||||
restore_pos=prefix_hit_length,
|
||||
)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
kv_prefix_cache.add_kv_cache(
|
||||
full_prompt_tokens, caches, cache_snapshots
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
mx_barrier(group)
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
|
||||
def _mlx_generate_with_mtp(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
|
||||
prompt_cache: KVCacheType,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""MTP speculative decoding generation path.
|
||||
|
||||
Uses the model's attached MTP module for speculative decoding,
|
||||
which can provide 1.5-2x speedup with ~81% acceptance rate.
|
||||
"""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
mtp_module: Any = model.mtp_module # type: ignore[attr-defined]
|
||||
|
||||
for out in mtp_speculative_generate(
|
||||
model=model,
|
||||
mtp_module=mtp_module,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=cast(list[Any], prompt_cache),
|
||||
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(f"{out.text} (from_draft={out.from_draft})")
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
__all__ = ["MTPModule", "mtp_speculative_generate"]
|
||||
@@ -1,207 +0,0 @@
|
||||
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
|
||||
|
||||
The MTP architecture predicts one additional token ahead using:
|
||||
1. hnorm - RMSNorm for hidden state normalization
|
||||
2. enorm - RMSNorm for embedding normalization
|
||||
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
|
||||
4. transformer_block - Single decoder layer (attention + MLP)
|
||||
5. Shared embedding/lm_head from main model
|
||||
|
||||
Forward pass:
|
||||
h_norm = hnorm(hidden_state)
|
||||
e_norm = enorm(embed(token))
|
||||
projected = eh_proj(concat([h_norm, e_norm]))
|
||||
new_hidden = transformer_block(projected)
|
||||
logits = lm_head(output_norm(new_hidden))
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
DeepseekV3Attention,
|
||||
DeepseekV3MLP,
|
||||
ModelArgs,
|
||||
)
|
||||
|
||||
MTP_LAYER_INDEX = 61
|
||||
|
||||
|
||||
class MTPModule(nn.Module):
|
||||
"""Multi-Token Prediction module for DeepSeek V3.
|
||||
|
||||
This module is initialized from the layer 61 weights that are normally
|
||||
discarded during model loading. It enables speculative decoding by
|
||||
predicting one token ahead using the hidden state from the main model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
shared_embedding: nn.Embedding,
|
||||
shared_lm_head: nn.Linear,
|
||||
output_norm: nn.RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# MTP-specific normalization layers
|
||||
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Projection: concatenated [hidden, embedding] -> hidden_size
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
# Single transformer block for MTP
|
||||
# Use a dense MLP since this is just a single layer
|
||||
self.transformer_block = MTPTransformerBlock(config)
|
||||
|
||||
# Share embedding and lm_head with main model
|
||||
self._shared_embedding = shared_embedding
|
||||
self._shared_lm_head = shared_lm_head
|
||||
self._output_norm = output_norm
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
cache: KVCache | None = None,
|
||||
mask: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass for MTP.
|
||||
|
||||
Args:
|
||||
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
|
||||
draft_token: Token to embed and combine with hidden state [batch, seq_len]
|
||||
cache: Optional KV cache for the MTP transformer block
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
tuple of (logits, new_hidden_state)
|
||||
"""
|
||||
# Get embedding of draft token
|
||||
embedding = self._shared_embedding(draft_token)
|
||||
|
||||
# Normalize hidden state and embedding
|
||||
h_norm = self.hnorm(hidden_state)
|
||||
e_norm = self.enorm(embedding)
|
||||
|
||||
# Project concatenated representation
|
||||
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
|
||||
projected = self.eh_proj(concatenated)
|
||||
|
||||
# Pass through single transformer block
|
||||
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
|
||||
|
||||
# Apply output norm and get logits
|
||||
normed_hidden = self._output_norm(new_hidden)
|
||||
logits = self._shared_lm_head(normed_hidden)
|
||||
|
||||
return logits, new_hidden
|
||||
|
||||
|
||||
class MTPTransformerBlock(nn.Module):
|
||||
"""Single transformer block for MTP.
|
||||
|
||||
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
|
||||
instead of MoE since this is just for the single MTP layer.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
# MTP uses dense MLP, not MoE
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass with residual connections."""
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
|
||||
"""Extract MTP-specific weights from layer 61.
|
||||
|
||||
The MTP layer has these weight patterns:
|
||||
- model.layers.61.enorm.weight -> MTP embedding normalization
|
||||
- model.layers.61.hnorm.weight -> MTP hidden normalization
|
||||
- model.layers.61.eh_proj.weight -> MTP projection layer
|
||||
- model.layers.61.self_attn.* -> MTP attention
|
||||
- model.layers.61.input_layernorm.* -> MTP layer norms
|
||||
- model.layers.61.post_attention_layernorm.*
|
||||
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
|
||||
|
||||
Args:
|
||||
weights: Full model weights dict
|
||||
|
||||
Returns:
|
||||
Dict of MTP-specific weights with keys renamed for MTPModule
|
||||
"""
|
||||
mtp_weights: dict[str, mx.array] = {}
|
||||
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mtp_prefix):
|
||||
# Remove the layer prefix to get relative path
|
||||
new_key = key[len(mtp_prefix) :]
|
||||
mtp_weights[new_key] = value
|
||||
|
||||
return mtp_weights
|
||||
|
||||
|
||||
def load_mtp_weights_into_module(
|
||||
mtp_module: MTPModule,
|
||||
mtp_weights: dict[str, mx.array],
|
||||
) -> None:
|
||||
"""Load extracted MTP weights into the MTPModule.
|
||||
|
||||
Args:
|
||||
mtp_module: The MTPModule instance to load weights into
|
||||
mtp_weights: Extracted MTP weights from extract_mtp_weights()
|
||||
"""
|
||||
# Map weight names to module attributes
|
||||
weight_mapping: dict[str, str] = {
|
||||
"enorm.weight": "enorm.weight",
|
||||
"hnorm.weight": "hnorm.weight",
|
||||
"eh_proj.weight": "eh_proj.weight",
|
||||
}
|
||||
|
||||
# Load direct mappings
|
||||
for src_name, dst_name in weight_mapping.items():
|
||||
if src_name in mtp_weights:
|
||||
parts = dst_name.split(".")
|
||||
obj: Any = mtp_module
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], mtp_weights[src_name])
|
||||
|
||||
# Load transformer block weights (self_attn, mlp, layer norms)
|
||||
transformer_prefixes = [
|
||||
"self_attn",
|
||||
"mlp",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
]
|
||||
|
||||
for prefix in transformer_prefixes:
|
||||
for key, value in mtp_weights.items():
|
||||
if key.startswith(prefix):
|
||||
# Navigate to the correct attribute
|
||||
parts = key.split(".")
|
||||
obj = mtp_module.transformer_block
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], value)
|
||||
@@ -1,506 +0,0 @@
|
||||
"""MTP Speculative Decoding for DeepSeek V3.
|
||||
|
||||
This module implements speculative decoding using the Multi-Token Prediction (MTP)
|
||||
layer from DeepSeek V3. The key difference from standard speculative decoding is
|
||||
that MTP requires hidden states from the main model, not just token predictions.
|
||||
|
||||
Based on vLLM/SGLang research:
|
||||
- 81-82% acceptance rate with k=1
|
||||
- 1.5-2x speedup at low QPS
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models import cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Generation stream for async operations
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPGenerationResponse:
|
||||
"""Response from MTP speculative generation.
|
||||
|
||||
Attributes:
|
||||
text: The next segment of decoded text.
|
||||
token: The next token.
|
||||
logprobs: A vector of log probabilities.
|
||||
from_draft: Whether the token was generated by the MTP draft module.
|
||||
prompt_tokens: The number of tokens in the prompt.
|
||||
prompt_tps: The prompt processing tokens-per-second.
|
||||
generation_tokens: The number of generated tokens.
|
||||
generation_tps: The tokens-per-second for generation.
|
||||
peak_memory: The peak memory used so far in GB.
|
||||
finish_reason: The reason the response is being sent: "length", "stop" or None.
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
"""Quantize KV cache entries if needed."""
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized")
|
||||
and hasattr(c, "offset")
|
||||
and c.offset >= quantized_kv_start
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
|
||||
class ModelWithHiddenStates(nn.Module):
|
||||
"""Wrapper to extract hidden states before lm_head.
|
||||
|
||||
This wrapper allows capturing the hidden states from the transformer
|
||||
layers before the final lm_head projection, which is needed for MTP.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self._base = base_model
|
||||
|
||||
def forward_with_hidden(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass that returns both logits and hidden states.
|
||||
|
||||
Args:
|
||||
inputs: Input token ids
|
||||
model_cache: KV cache
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, hidden_states)
|
||||
"""
|
||||
# Call the inner model (transformer layers + norm)
|
||||
hidden: mx.array = self._base.model(inputs, model_cache)
|
||||
# Get logits from lm_head
|
||||
logits: mx.array = self._base.lm_head(hidden)
|
||||
return logits, hidden
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> mx.array:
|
||||
"""Standard forward pass returning only logits."""
|
||||
return cast(mx.array, self._base(inputs, cache=model_cache))
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
"""Access layers for cache creation."""
|
||||
return cast(list[nn.Module], self._base.layers)
|
||||
|
||||
|
||||
def mtp_speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
*,
|
||||
num_draft_tokens: int = 1,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
mtp_cache: KVCache | None = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[int, mx.array, bool], None, None]:
|
||||
"""MTP speculative decoding generator.
|
||||
|
||||
Unlike standard speculative decoding where the draft model only needs tokens,
|
||||
MTP requires the hidden states from the main model. This generator:
|
||||
|
||||
1. Runs the main model to get logits AND hidden states
|
||||
2. Uses MTP module with hidden state + sampled token to predict next token
|
||||
3. Verifies MTP predictions with the main model
|
||||
4. Accepts/rejects based on matching
|
||||
|
||||
Args:
|
||||
prompt: The input prompt as token ids
|
||||
model: The main model (must support return_hidden=True)
|
||||
mtp_module: The MTP module for draft prediction
|
||||
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
sampler: Optional sampler function for token selection
|
||||
logits_processors: Optional list of logits processors
|
||||
prompt_cache: KV cache for the main model
|
||||
mtp_cache: KV cache for the MTP module
|
||||
prefill_step_size: Step size for prompt processing
|
||||
kv_bits: Bits for KV cache quantization
|
||||
kv_group_size: Group size for KV cache quantization
|
||||
quantized_kv_start: Step to begin cache quantization
|
||||
|
||||
Yields:
|
||||
Tuple of (token, logprobs, from_draft)
|
||||
"""
|
||||
y = prompt.astype(mx.uint32)
|
||||
prev_tokens: mx.array | None = None
|
||||
|
||||
# Wrap model to get hidden states
|
||||
wrapped_model = (
|
||||
model
|
||||
if isinstance(model, ModelWithHiddenStates)
|
||||
else ModelWithHiddenStates(model)
|
||||
)
|
||||
|
||||
# Create caches if needed
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
if mtp_cache is None:
|
||||
mtp_cache = KVCache()
|
||||
|
||||
final_sampler = (
|
||||
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
|
||||
)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _process_and_sample(
|
||||
tokens: mx.array | None,
|
||||
logits: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Process logits and sample tokens."""
|
||||
nonlocal logits_processors
|
||||
processed_logits = logits
|
||||
if logits_processors:
|
||||
for processor in logits_processors:
|
||||
processed_logits = processor(
|
||||
tokens if tokens is not None else mx.array([]), processed_logits
|
||||
)
|
||||
|
||||
logprobs = processed_logits - mx.logsumexp(
|
||||
processed_logits, axis=-1, keepdims=True
|
||||
)
|
||||
sampled = final_sampler(logprobs)
|
||||
return sampled, logprobs
|
||||
|
||||
def _main_model_step_with_hidden(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array, mx.array]:
|
||||
"""Run main model step with hidden state return."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits, hidden = wrapped_model.forward_with_hidden(
|
||||
input_y[None], prompt_cache
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
|
||||
|
||||
def _main_model_step(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run main model step without hidden state."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits = wrapped_model.forward(input_y[None], prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0)
|
||||
|
||||
def _mtp_draft(
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Generate draft token using MTP module."""
|
||||
with mx.stream(generation_stream):
|
||||
logits, new_hidden = mtp_module(
|
||||
hidden_state,
|
||||
draft_token,
|
||||
cache=mtp_cache,
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
sampled, _ = _process_and_sample(None, logits)
|
||||
return sampled, new_hidden
|
||||
|
||||
def _prefill(input_y: mx.array) -> mx.array:
|
||||
"""Prefill the prompt cache."""
|
||||
result_y = input_y
|
||||
while result_y.size > prefill_step_size:
|
||||
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
result_y = result_y[prefill_step_size:]
|
||||
mx.clear_cache()
|
||||
return result_y
|
||||
|
||||
def _rewind_cache(num_draft: int, num_accept: int) -> None:
|
||||
"""Rewind caches after rejection."""
|
||||
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
|
||||
|
||||
# Prefill phase
|
||||
with mx.stream(generation_stream):
|
||||
y = _prefill(y)
|
||||
|
||||
ntoks = 0
|
||||
num_draft = 0
|
||||
n_accepted = 0
|
||||
last_hidden: mx.array | None = None
|
||||
|
||||
try:
|
||||
# Initial step to get first token and hidden state
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
while ntoks < max_tokens:
|
||||
# Draft phase: use MTP to predict next token
|
||||
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
|
||||
|
||||
if num_draft > 0 and last_hidden is not None:
|
||||
# Use MTP to draft
|
||||
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
|
||||
mx.eval(draft_token, draft_hidden)
|
||||
|
||||
# Verify with main model
|
||||
# Feed the drafted token to main model
|
||||
verify_input = mx.concatenate([y, draft_token.flatten()])
|
||||
verify_sampled, verify_logprobs, new_hidden = (
|
||||
_main_model_step_with_hidden(verify_input)
|
||||
)
|
||||
mx.eval(verify_sampled, verify_logprobs, new_hidden)
|
||||
|
||||
# Check if draft matches verification
|
||||
draft_token_val = int(draft_token.item())
|
||||
verify_token_val = (
|
||||
int(verify_sampled[0].item())
|
||||
if verify_sampled.shape[0] > 1
|
||||
else int(verify_sampled.item())
|
||||
)
|
||||
|
||||
# Yield the current token (not from draft)
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
if draft_token_val == verify_token_val:
|
||||
# Draft accepted
|
||||
n_accepted += 1
|
||||
ntoks += 1
|
||||
draft_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
yield draft_token_val, draft_logprobs, True
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
# Continue with the token after the draft
|
||||
y = (
|
||||
verify_sampled[-1:]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[-1]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden
|
||||
else:
|
||||
# Draft rejected - rewind and use verified token
|
||||
_rewind_cache(1, 0)
|
||||
y = (
|
||||
verify_sampled[:1]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = (
|
||||
new_hidden[:, :1, :] if new_hidden is not None else None
|
||||
)
|
||||
else:
|
||||
# No drafting, just do normal generation
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
if ntoks % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
finally:
|
||||
_rewind_cache(num_draft, n_accepted)
|
||||
|
||||
|
||||
def mtp_speculative_generate(
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
num_draft_tokens: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int | None = None,
|
||||
) -> Generator[MTPGenerationResponse, None, None]:
|
||||
"""High-level MTP speculative generation with text output.
|
||||
|
||||
Args:
|
||||
model: The main model
|
||||
mtp_module: The MTP module for draft prediction
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
prompt: Input prompt (string, array, or token list)
|
||||
max_tokens: Maximum tokens to generate
|
||||
sampler: Optional sampler function
|
||||
logits_processors: Optional logits processors
|
||||
prompt_cache: Optional KV cache
|
||||
num_draft_tokens: Number of draft tokens
|
||||
prefill_step_size: Prefill step size
|
||||
kv_group_size: KV group size
|
||||
kv_bits: KV bits
|
||||
|
||||
Yields:
|
||||
MTPGenerationResponse objects with text and metadata
|
||||
"""
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
bos_token = getattr(tokenizer, "bos_token", None)
|
||||
add_special_tokens = bos_token is None or not prompt.startswith(
|
||||
str(bos_token)
|
||||
)
|
||||
encoded: list[int] = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
prompt = mx.array(encoded)
|
||||
else:
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
|
||||
|
||||
token_generator = mtp_speculative_generate_step(
|
||||
prompt,
|
||||
model,
|
||||
mtp_module,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_tps = 0.0
|
||||
token = 0
|
||||
logprobs: mx.array = mx.array([0.0])
|
||||
from_draft = False
|
||||
n = 0
|
||||
|
||||
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = float(prompt.size) / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token)
|
||||
if (n + 1) == max_tokens:
|
||||
break
|
||||
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
detokenizer.finalize()
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop" if token in eos_token_ids else "length",
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for MTP module."""
|
||||
@@ -1,412 +0,0 @@
|
||||
"""Unit tests for MTP module components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTP_LAYER_INDEX,
|
||||
MTPModule,
|
||||
MTPTransformerBlock,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
|
||||
class MockModelArgs:
|
||||
"""Mock ModelArgs for testing without importing deepseek_v3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
vocab_size: int = 1000,
|
||||
q_lora_rank: int | None = None,
|
||||
kv_lora_rank: int = 64,
|
||||
qk_rope_head_dim: int = 16,
|
||||
v_head_dim: int = 32,
|
||||
qk_nope_head_dim: int = 32,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: dict | None = None,
|
||||
attention_bias: bool = False,
|
||||
max_position_embeddings: int = 2048,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
|
||||
class TestExtractMTPWeights:
|
||||
"""Tests for extract_mtp_weights function."""
|
||||
|
||||
def test_extracts_layer_61_weights(self) -> None:
|
||||
"""Should extract only layer 61 weights."""
|
||||
weights = {
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.61.enorm.weight": mx.ones((10,)),
|
||||
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
|
||||
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
|
||||
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.embed_tokens.weight": mx.zeros((100, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 3
|
||||
assert "enorm.weight" in mtp_weights
|
||||
assert "hnorm.weight" in mtp_weights
|
||||
assert "eh_proj.weight" in mtp_weights
|
||||
# Check values are preserved
|
||||
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
|
||||
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
|
||||
|
||||
def test_returns_empty_dict_when_no_layer_61(self) -> None:
|
||||
"""Should return empty dict when layer 61 doesn't exist."""
|
||||
weights = {
|
||||
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 0
|
||||
|
||||
def test_handles_nested_layer_61_weights(self) -> None:
|
||||
"""Should handle nested weight paths like self_attn.q_proj.weight."""
|
||||
weights = {
|
||||
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
|
||||
(10, 10)
|
||||
),
|
||||
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert "self_attn.q_a_proj.weight" in mtp_weights
|
||||
assert "mlp.gate_proj.weight" in mtp_weights
|
||||
|
||||
|
||||
class TestMTPTransformerBlock:
|
||||
"""Tests for MTPTransformerBlock."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64, intermediate_size=128, num_attention_heads=2
|
||||
)
|
||||
|
||||
def test_forward_shape(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should preserve input shape."""
|
||||
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
|
||||
output = block(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_forward_with_mask(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should work with attention mask."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
# Create causal mask
|
||||
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
|
||||
|
||||
output = block(x, mask=mask)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestMTPModule:
|
||||
"""Tests for MTPModule."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def shared_components(
|
||||
self, config: MockModelArgs
|
||||
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
return embedding, lm_head, output_norm
|
||||
|
||||
def test_initialization(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should initialize with correct components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
assert mtp.hnorm is not None
|
||||
assert mtp.enorm is not None
|
||||
assert mtp.eh_proj is not None
|
||||
assert mtp.transformer_block is not None
|
||||
|
||||
def test_forward_output_shapes(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""Forward pass should return correct output shapes."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1
|
||||
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
|
||||
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
|
||||
|
||||
logits, new_hidden = mtp(hidden_state, draft_token)
|
||||
|
||||
assert logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_shares_embedding_and_lm_head(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should use shared embedding and lm_head."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Verify they're the same objects
|
||||
assert mtp._shared_embedding is embedding
|
||||
assert mtp._shared_lm_head is lm_head
|
||||
assert mtp._output_norm is output_norm
|
||||
|
||||
|
||||
class TestLoadMTPWeights:
|
||||
"""Tests for load_mtp_weights_into_module."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
|
||||
"""Should load enorm and hnorm weights."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Create test weights
|
||||
test_enorm = mx.ones((config.hidden_size,)) * 3.0
|
||||
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
|
||||
mtp_weights = {
|
||||
"enorm.weight": test_enorm,
|
||||
"hnorm.weight": test_hnorm,
|
||||
}
|
||||
|
||||
load_mtp_weights_into_module(mtp, mtp_weights)
|
||||
|
||||
assert mx.allclose(mtp.enorm.weight, test_enorm)
|
||||
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
|
||||
|
||||
|
||||
class TestSanitizePatch:
|
||||
"""Tests for the sanitize patching logic."""
|
||||
|
||||
def test_patch_preserves_layer_61(self) -> None:
|
||||
"""Patching sanitize should preserve layer 61 weights."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
# Get original sanitize behavior
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
# Apply patch
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
# Note: we can't easily test the full sanitize without a real model
|
||||
# This test verifies the patch is applied
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
# Verify restore worked
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_restore_sanitize(self) -> None:
|
||||
"""Restoring sanitize should return to original behavior."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_double_patch_is_safe(self) -> None:
|
||||
"""Calling patch twice should be safe (idempotent)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
patched_sanitize = model_cls.sanitize
|
||||
|
||||
# Patch again - should be no-op
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is patched_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
|
||||
class TestModelIdDetection:
|
||||
"""Tests for DeepSeek V3 model ID detection."""
|
||||
|
||||
def test_detects_deepseek_v3(self) -> None:
|
||||
"""Should detect DeepSeek V3 model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
|
||||
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
|
||||
|
||||
def test_detects_deepseek_r1(self) -> None:
|
||||
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
|
||||
|
||||
def test_rejects_non_deepseek(self) -> None:
|
||||
"""Should reject non-DeepSeek model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
|
||||
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
|
||||
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Detection should be case insensitive."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
|
||||
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
|
||||
|
||||
|
||||
class TestFlattenParams:
|
||||
"""Tests for parameter flattening utility."""
|
||||
|
||||
def test_flattens_nested_dict(self) -> None:
|
||||
"""Should flatten nested parameter dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"model": {
|
||||
"layers": {
|
||||
"0": {
|
||||
"weight": mx.zeros((10,)),
|
||||
}
|
||||
},
|
||||
"embed": mx.ones((5,)),
|
||||
}
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert "model.layers.0.weight" in flat
|
||||
assert "model.embed" in flat
|
||||
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
|
||||
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
|
||||
|
||||
def test_handles_flat_dict(self) -> None:
|
||||
"""Should handle already-flat dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"weight": mx.zeros((10,)),
|
||||
"bias": mx.ones((10,)),
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert flat == params
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Unit tests for MTP speculative decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import (
|
||||
ModelWithHiddenStates,
|
||||
maybe_quantize_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
class MockModel(nn.Module):
|
||||
"""Mock model for testing speculative decoding."""
|
||||
|
||||
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Create simple model components
|
||||
self.model = MockInnerModel(hidden_size)
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
hidden = self.model(inputs, cache)
|
||||
return self.lm_head(hidden)
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
return self._layers
|
||||
|
||||
|
||||
class MockInnerModel(nn.Module):
|
||||
"""Mock inner model (like DeepseekV3Model)."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(100, hidden_size)
|
||||
self.norm = nn.RMSNorm(hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
# Simple embedding + norm
|
||||
embedded = self.embed_tokens(inputs)
|
||||
return self.norm(embedded)
|
||||
|
||||
|
||||
class TestModelWithHiddenStates:
|
||||
"""Tests for ModelWithHiddenStates wrapper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel(hidden_size=64, vocab_size=100)
|
||||
|
||||
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
|
||||
"""Standard forward should return logits."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits = wrapped.forward(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
|
||||
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
|
||||
"""Forward with hidden should return (logits, hidden)."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits, hidden = wrapped.forward_with_hidden(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
assert hidden.shape == (1, 3, mock_model.hidden_size)
|
||||
|
||||
def test_layers_property(self, mock_model: MockModel) -> None:
|
||||
"""Should expose layers property from base model."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
|
||||
assert wrapped.layers == mock_model.layers
|
||||
assert len(wrapped.layers) == 3
|
||||
|
||||
|
||||
class TestMaybeQuantizeKVCache:
|
||||
"""Tests for KV cache quantization."""
|
||||
|
||||
def test_no_quantization_when_bits_none(self) -> None:
|
||||
"""Should not quantize when kv_bits is None."""
|
||||
cache = [MockCache(offset=100)]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
cache,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=None,
|
||||
)
|
||||
|
||||
# Cache should be unchanged
|
||||
assert not hasattr(cache[0], "quantized")
|
||||
|
||||
def test_respects_quantized_kv_start(self) -> None:
|
||||
"""Should only quantize caches past the start threshold."""
|
||||
cache_below = MockCache(offset=30)
|
||||
cache_above = MockCache(offset=100)
|
||||
caches = [cache_below, cache_above]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
caches,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=4,
|
||||
)
|
||||
|
||||
# Only cache_above should be quantized
|
||||
assert not getattr(cache_below, "was_quantized", False)
|
||||
assert getattr(caches[1], "was_quantized", False)
|
||||
|
||||
|
||||
class MockCache:
|
||||
"""Mock KV cache for testing."""
|
||||
|
||||
def __init__(self, offset: int = 0) -> None:
|
||||
self.offset = offset
|
||||
self.was_quantized = False
|
||||
|
||||
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
|
||||
quantized = MockCache(self.offset)
|
||||
quantized.was_quantized = True
|
||||
return quantized
|
||||
|
||||
|
||||
class TestSpeculativeDecodingLogic:
|
||||
"""Tests for the core speculative decoding logic."""
|
||||
|
||||
def test_draft_acceptance_identical_tokens(self) -> None:
|
||||
"""When draft matches verification, both should be accepted."""
|
||||
# This tests the logic, not the full generator
|
||||
draft_token = 42
|
||||
verify_token = 42
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert accepted
|
||||
|
||||
def test_draft_rejection_different_tokens(self) -> None:
|
||||
"""When draft differs from verification, draft should be rejected."""
|
||||
draft_token = 42
|
||||
verify_token = 99
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert not accepted
|
||||
|
||||
|
||||
class TestMTPGenerationResponse:
|
||||
"""Tests for MTPGenerationResponse dataclass."""
|
||||
|
||||
def test_response_creation(self) -> None:
|
||||
"""Should create response with all fields."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="Hello",
|
||||
token=42,
|
||||
logprobs=mx.array([0.1, 0.2]),
|
||||
from_draft=True,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=5,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert response.text == "Hello"
|
||||
assert response.token == 42
|
||||
assert response.from_draft is True
|
||||
assert response.finish_reason is None
|
||||
|
||||
def test_response_with_finish_reason(self) -> None:
|
||||
"""Should handle finish_reason."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="",
|
||||
token=0,
|
||||
logprobs=mx.array([0.0]),
|
||||
from_draft=False,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=100,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason="length",
|
||||
)
|
||||
|
||||
assert response.finish_reason == "length"
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full MTP pipeline."""
|
||||
|
||||
def test_mtp_module_with_mock_model(self) -> None:
|
||||
"""Test MTP module can be created and run with mock components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Create mock config
|
||||
class MockConfig:
|
||||
hidden_size = 64
|
||||
intermediate_size = 128
|
||||
num_attention_heads = 2
|
||||
num_key_value_heads = 2
|
||||
rms_norm_eps = 1e-6
|
||||
q_lora_rank = None
|
||||
kv_lora_rank = 32
|
||||
qk_rope_head_dim = 8
|
||||
v_head_dim = 16
|
||||
qk_nope_head_dim = 16
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
attention_bias = False
|
||||
max_position_embeddings = 2048
|
||||
|
||||
config = MockConfig()
|
||||
embedding = nn.Embedding(100, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
hidden = mx.random.normal((1, 1, config.hidden_size))
|
||||
token = mx.array([[5]])
|
||||
|
||||
logits, new_hidden = mtp(hidden, token)
|
||||
|
||||
assert logits.shape == (1, 1, 100)
|
||||
assert new_hidden.shape == (1, 1, config.hidden_size)
|
||||
# Verify outputs are valid (not NaN)
|
||||
assert not mx.any(mx.isnan(logits))
|
||||
assert not mx.any(mx.isnan(new_hidden))
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import resource
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -25,7 +24,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
MTP_ENABLED,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -71,142 +69,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
# MTP (Multi-Token Prediction) support for DeepSeek V3
|
||||
MTP_LAYER_INDEX = 61
|
||||
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
|
||||
|
||||
|
||||
def _is_deepseek_v3_model(model: nn.Module) -> bool:
|
||||
"""Check if the model is DeepSeek V3."""
|
||||
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)
|
||||
|
||||
|
||||
def _might_be_deepseek_v3(model_id: str) -> bool:
|
||||
"""Check if model ID suggests this might be DeepSeek V3."""
|
||||
model_id_lower = model_id.lower()
|
||||
return "deepseek" in model_id_lower and (
|
||||
"v3" in model_id_lower or "r1" in model_id_lower
|
||||
)
|
||||
|
||||
|
||||
def _patch_deepseek_sanitize_for_mtp() -> None:
|
||||
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights."""
|
||||
global _original_deepseek_sanitize
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
if _original_deepseek_sanitize is not None:
|
||||
return
|
||||
|
||||
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
|
||||
|
||||
def sanitize_with_mtp(
|
||||
self: DeepSeekV3Model, weights: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if _original_deepseek_sanitize is None:
|
||||
raise RuntimeError(
|
||||
"_original_deepseek_sanitize is None - patch not applied correctly"
|
||||
)
|
||||
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
|
||||
mtp_weights = {
|
||||
k: v
|
||||
for k, v in weights.items()
|
||||
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
|
||||
}
|
||||
return {**original_result, **mtp_weights}
|
||||
|
||||
DeepSeekV3Model.sanitize = sanitize_with_mtp
|
||||
|
||||
|
||||
def _restore_deepseek_sanitize() -> None:
|
||||
"""Restore the original DeepSeek V3 sanitize method."""
|
||||
global _original_deepseek_sanitize
|
||||
if _original_deepseek_sanitize is None:
|
||||
return
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
|
||||
|
||||
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
|
||||
_original_deepseek_sanitize = None
|
||||
|
||||
|
||||
def _flatten_params(
|
||||
params: dict[str, Any],
|
||||
prefix: str = "",
|
||||
) -> dict[str, mx.array]:
|
||||
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
|
||||
result: dict[str, mx.array] = {}
|
||||
for key, value in params.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, mx.array):
|
||||
result[full_key] = value
|
||||
elif isinstance(value, dict):
|
||||
result.update(_flatten_params(value, full_key))
|
||||
return result
|
||||
|
||||
|
||||
def _extract_mtp_module(model: nn.Module) -> Any | None:
|
||||
"""Extract MTP module from a loaded DeepSeek V3 model."""
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTPModule,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
try:
|
||||
inner_model = getattr(model, "model", None)
|
||||
if inner_model is None or not hasattr(inner_model, "layers"):
|
||||
logger.debug("Model doesn't have expected structure for MTP extraction")
|
||||
return None
|
||||
|
||||
layers: list[nn.Module] = inner_model.layers # type: ignore[assignment]
|
||||
if len(layers) <= MTP_LAYER_INDEX:
|
||||
logger.debug(
|
||||
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
|
||||
)
|
||||
return None
|
||||
|
||||
config = getattr(model, "args", None)
|
||||
if config is None:
|
||||
logger.debug("Could not get model config for MTP module")
|
||||
return None
|
||||
|
||||
embed_tokens = getattr(inner_model, "embed_tokens", None)
|
||||
lm_head = getattr(model, "lm_head", None)
|
||||
norm = getattr(inner_model, "norm", None)
|
||||
|
||||
if embed_tokens is None or lm_head is None or norm is None:
|
||||
logger.debug("Could not get required model components for MTP")
|
||||
return None
|
||||
|
||||
mtp_module = MTPModule(
|
||||
config=config,
|
||||
shared_embedding=embed_tokens,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=norm,
|
||||
)
|
||||
|
||||
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
|
||||
model_weights = _flatten_params(raw_params)
|
||||
mtp_weights = extract_mtp_weights(model_weights)
|
||||
|
||||
if not mtp_weights:
|
||||
logger.debug("No MTP weights found in model parameters")
|
||||
return None
|
||||
|
||||
load_mtp_weights_into_module(mtp_module, mtp_weights)
|
||||
|
||||
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
|
||||
inner_model.layers = new_layers # noqa: B010
|
||||
|
||||
logger.info(
|
||||
f"Extracted MTP module, main model now has {len(new_layers)} layers"
|
||||
)
|
||||
return mtp_module
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract MTP module: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -339,52 +201,28 @@ def load_mlx_items(
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
model_id = bound_instance.bound_shard.model_card.model_id
|
||||
mtp_module = None
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
# Patch sanitize for MTP if this might be DeepSeek V3
|
||||
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
|
||||
if should_try_mtp:
|
||||
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
try:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=not should_try_mtp)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
# Extract MTP module if available
|
||||
if should_try_mtp and _is_deepseek_v3_model(model):
|
||||
mtp_module = _extract_mtp_module(model)
|
||||
if mtp_module is not None:
|
||||
logger.info("Successfully extracted MTP module from DeepSeek V3")
|
||||
|
||||
finally:
|
||||
if should_try_mtp:
|
||||
_restore_deepseek_sanitize()
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
# Store MTP module on the model for later access
|
||||
if mtp_module is not None:
|
||||
model.mtp_module = mtp_module # noqa: B010
|
||||
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@@ -652,6 +490,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
|
||||
return think_token is not None and prompt.rstrip().endswith(think_token)
|
||||
|
||||
|
||||
def fix_unmatched_think_end_tokens(
|
||||
tokens: mx.array, tokenizer: TokenizerWrapper
|
||||
) -> mx.array:
|
||||
if not tokenizer.has_thinking:
|
||||
return tokens
|
||||
assert tokenizer.think_start_id
|
||||
assert tokenizer.think_end_id
|
||||
think_start_id: int = tokenizer.think_start_id
|
||||
think_end_id: int = tokenizer.think_end_id
|
||||
token_list: list[int] = cast(list[int], tokens.tolist())
|
||||
result: list[int] = []
|
||||
depth = 0
|
||||
for token in token_list:
|
||||
if token == think_start_id:
|
||||
depth += 1
|
||||
elif token == think_end_id:
|
||||
if depth == 0:
|
||||
result.append(think_start_id)
|
||||
else:
|
||||
depth -= 1
|
||||
result.append(token)
|
||||
return mx.array(result)
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
"""
|
||||
A KVCache that pretends to exist but holds zero tokens.
|
||||
|
||||
@@ -98,21 +98,23 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
|
||||
@@ -193,7 +193,7 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
@@ -226,6 +226,7 @@ def main(
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
@@ -274,6 +275,7 @@ def main(
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
@@ -627,7 +629,7 @@ def parse_thinking_models(
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id, # type: ignore
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
@@ -8,10 +8,8 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
@@ -49,7 +47,6 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
@@ -93,28 +90,29 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._forward_events)
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
@@ -122,10 +120,6 @@ class RunnerSupervisor:
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache = KVPrefixCache()
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,10 +142,12 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
# Snapshots should be available for models with non-KV caches
|
||||
assert len(snapshots) > 0
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -159,10 +161,10 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -170,7 +172,7 @@ class TestKVPrefixCacheWithModel:
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, tokens
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
@@ -191,10 +193,12 @@ class TestKVPrefixCacheWithModel:
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = TextGenerationTaskParams(
|
||||
@@ -212,13 +216,12 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_prompt
|
||||
model, long_tokens
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
# remaining_tokens covers from snapshot restore position to end
|
||||
assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
@@ -235,15 +238,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
@@ -273,15 +276,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
@@ -298,7 +301,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -328,7 +331,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -352,20 +355,20 @@ class TestKVPrefixCacheWithModel:
|
||||
# Second call should find a prefix match (the stored cache contains
|
||||
# prompt + generated tokens, which shares the prompt prefix)
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
model, prompt_tokens
|
||||
)
|
||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
||||
# so this is a prefix match where our prompt is fully contained
|
||||
assert matched_index == 0
|
||||
# Exact match: remaining_tokens is just the last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
||||
# Exact match: remaining_tokens is just the last token and the one before
|
||||
assert len(remaining_tokens) == 2
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])
|
||||
|
||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -444,7 +447,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -481,7 +484,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -495,7 +498,7 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
|
||||
@@ -505,19 +508,10 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache._last_used[2] = 100.0
|
||||
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
|
||||
|
||||
# Simulate memory pressure: active memory exceeds threshold
|
||||
fake_limit = 1000
|
||||
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
|
||||
return_value=fake_active,
|
||||
),
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.device_info",
|
||||
return_value={"max_recommended_working_set_size": fake_limit},
|
||||
),
|
||||
# Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
|
||||
with patch(
|
||||
"exo.worker.engines.mlx.cache.get_memory_used_percentage",
|
||||
return_value=0.95,
|
||||
):
|
||||
# Trigger eviction by adding a new entry
|
||||
task = TextGenerationTaskParams(
|
||||
@@ -529,14 +523,11 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
# Since fake_active stays above threshold after each eviction (we don't change it),
|
||||
# all old entries get evicted, leaving only the newly added one
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)
|
||||
|
||||
@@ -34,6 +34,7 @@ TOKENIZER_FILE_PATTERNS = [
|
||||
"added_tokens.json",
|
||||
"tokenizer.model",
|
||||
"tokenization_*.py", # Custom tokenizer implementations
|
||||
"tool_declaration_ts.py", # Dependency of tokenization_kimi.py
|
||||
]
|
||||
|
||||
|
||||
|
||||
53
tests/auto_bench.sh
Executable file
53
tests/auto_bench.sh
Executable file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ $# -lt 1 ] && {
|
||||
echo "Usage: $0 host1 [host2 ...]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
[ -z "$(git status --porcelain)" ] || {
|
||||
echo "Uncommitted changes"
|
||||
exit 1
|
||||
}
|
||||
|
||||
commit=$(git rev-parse HEAD)
|
||||
git fetch -q origin
|
||||
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
|
||||
echo "Not pushed to origin"
|
||||
exit 1
|
||||
}
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
sleep 1
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
}
|
||||
trap 'cleanup' EXIT INT TERM
|
||||
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
|
||||
done
|
||||
wait
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..." 1>&2
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
|
||||
echo "Waiting 30s for cluster setup" 1>&2
|
||||
sleep 30
|
||||
echo "EXO loaded" 1>&2
|
||||
bench_runner="${hosts[0]}"
|
||||
mkdir -p "./bench/$commit"
|
||||
nix run .#exo-get-all-models-on-cluster -- "$bench_runner" | while IFS= read -r model; do
|
||||
echo "running bench for $model" 1>&2
|
||||
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$bench_runner@$bench_runner" "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring" >>"./bench/$commit/${model//\//--}.json"
|
||||
echo
|
||||
done
|
||||
36
tests/get_all_models_on_cluster.py
Executable file
36
tests/get_all_models_on_cluster.py
Executable file
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
from urllib.request import urlopen
|
||||
|
||||
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
|
||||
ts = subprocess.run(
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = next(
|
||||
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
|
||||
) or sys.exit(f"{h} not found in tailscale")
|
||||
with urlopen(f"http://{ip}:52415/state", timeout=5) as r:
|
||||
data = json.loads(r.read()).get("downloads", {})
|
||||
|
||||
|
||||
def mid(x: dict[str, Any]) -> str | None:
|
||||
for k in (
|
||||
"DownloadCompleted",
|
||||
"shardMetadata",
|
||||
"PipelineShardMetadata",
|
||||
"modelCard",
|
||||
"modelId",
|
||||
):
|
||||
x = x.get(k, {})
|
||||
return cast(str | None, x if x != {} else None)
|
||||
|
||||
|
||||
common = set[str].intersection(
|
||||
*[{m for d in nid if (m := mid(d))} for nid in data.values()]
|
||||
)
|
||||
for c in common:
|
||||
print(c)
|
||||
@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
wait
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
@@ -34,21 +34,13 @@ reset=$'\e[0m'
|
||||
i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
{
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
|
||||
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
} <<'EOF'
|
||||
set -euo pipefail
|
||||
cd exo
|
||||
git fetch -q origin
|
||||
git checkout -q "$1"
|
||||
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
|
||||
EOF
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..."
|
||||
until curl -sf "http://$host:52415/models"; do sleep 1; done
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user