Compare commits

14 Commits

Author SHA1 Message Date
Evan
103cbdee58 shfmt ignore 2026-02-06 17:08:26 +00:00
Evan
dbcc829625 sdk version? 2026-02-06 16:58:18 +00:00
Evan
30b384e2e6 vendor the apple sdk nix package 2026-02-06 16:40:10 +00:00
Evan
6675feed71 enable MLX_BUILD_CPU in nix built mlx 2026-02-06 16:14:04 +00:00
Evan Quiney
9b5cae3db6 auto bench (#1405)
runs exo_bench remotely with some nice git QoL

## usage
run `tests/auto_bench.sh host1 [host2]`

exo_bench will be run on those hosts for all models currently
downloaded, with its output saved to bench/commit_hash/*.json
2026-02-06 15:35:46 +00:00
Jake Hillion
cf7201f91e pyproject: set minimum uv version
The uv.lock is churning constantly as different uv versions bounce it
between revisions. This is made worse by GitHub automatically hiding the
uv.lock changes, meaning it's hard to notice when this goes wrong.

Set a minimum version for `uv` in pyproject.toml to fix this. I tried
quite a few versions (not all) and found that 0.8.6 sets the revision to
3, which I believe is the latest. 0.8.6 is from August 2025, so it has
been around for a while.
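
A minimal sketch of what that constraint can look like in pyproject.toml
(assuming uv's `required-version` setting; see the actual diff for the
exact form used):

```toml
[tool.uv]
# Reject uv versions old enough to write a different lockfile revision.
required-version = ">=0.8.6"
```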

Test plan:

```
jake@maverick:/data/users/jake/repos/exo/ > git checkout main uv.lock
jake@maverick:/data/users/jake/repos/exo/ > nix shell github:nixos/nixpkgs/3dce7f4a77812afd69efcbfe15e5223f98c5c69e#uv --command sh -c 'uv add pip --frozen && uv lock && uv remove pip --frozen && uv lock && uv --version'

Resolved 140 packages in 147ms
Added pip v26.0.1
Resolved 139 packages in 48ms
Removed pip v26.0.1
uv 0.8.6
```
2026-02-06 15:28:10 +00:00
rltakashige
b315035ae0 Add minimax and fix qwen sharding strategies (#1318)
## Motivation

MiniMax tensor sharding does not produce outputs equivalent to running
the model on a single node, because RMSNorm weights cannot be split
without affecting the output.
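
As a toy illustration (numpy, not exo's actual code): RMSNorm computes
its normalizing statistic over the full hidden dimension, so naively
splitting the activation and weight across shards changes the statistic
each shard sees.

```python
import numpy as np

def rmsnorm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Normalize over the full last dimension, then scale elementwise.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * w

rng = np.random.default_rng(0)
x, w = rng.normal(size=8), rng.normal(size=8)

full = rmsnorm(x, w)
# Naive tensor sharding: each half normalizes over only its own slice.
sharded = np.concatenate([rmsnorm(x[:4], w[:4]), rmsnorm(x[4:], w[4:])])

print(np.allclose(full, sharded))  # False: per-shard statistics differ
```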

Qwen3Next sharding was broken, and something with Qwen3MoE was likely
changed upstream, as several variables no longer exist.

This also ballooned into fixing prefix caching for non-standard models
as Qwen3Next was behaving weirdly.

## Test Plan

### Manual Testing
Worked for an 8-hour-long eval at the same performance and with a more
similar completion/reasoning token distribution.

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
2026-02-06 13:26:59 +00:00
rltakashige
c8dbbee27b skip tensor ring on bench (#1403)
2026-02-06 13:06:59 +00:00
rltakashige
f0107e9670 Fix offline no cache (#1402)
## Motivation

In offline mode, exo complains if there is no caches directory, even if
the files are there.

## Changes

Check the safetensors index and the directory structure to build the
caches directory.
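
A rough sketch of the check (hypothetical helper, not the PR's exact
code): treat a model as present when every shard named in the
safetensors index exists on disk, instead of requiring the caches
directory up front.

```python
import json
from pathlib import Path

def shards_present(model_dir: Path) -> bool:
    """True if every weight file named in the safetensors index exists."""
    index = model_dir / "model.safetensors.index.json"
    if not index.exists():
        # Single-file models ship no index; fall back to the lone shard.
        return (model_dir / "model.safetensors").exists()
    weight_map = json.loads(index.read_text())["weight_map"]
    return all((model_dir / name).exists() for name in set(weight_map.values()))
```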

## Test Plan

### Manual Testing
<img width="2338" height="1102" alt="image"
src="https://github.com/user-attachments/assets/ad769911-399b-4fca-ac80-aeaa046af06b"
/>
<img width="656" height="1668" alt="image"
src="https://github.com/user-attachments/assets/6080986c-3904-4600-a340-8c70f1b33266"
/>
2026-02-06 12:57:01 +00:00
Hunter Bown
9f502793c1 fix: retry downloads on transient errors instead of breaking (#1398)
## Motivation

`download_file_with_retry()` has a `break` in the generic exception
handler that exits the retry loop after the first transient failure.
This means network timeouts, connection resets, and server errors all
cause an immediate download failure — the two remaining retry attempts
never run.

## Changes

**download_utils.py**: Replaced `break` with logging and exponential
backoff in the generic exception handler, matching the existing
rate-limit handler behavior.

Before:
```python
except Exception as e:
    on_connection_lost()
    if attempt == n_attempts - 1:
        raise e
    break  # exits loop immediately
```

After:
```python
except Exception as e:
    on_connection_lost()
    if attempt == n_attempts - 1:
        raise e
    logger.error(f"Download error on attempt {attempt + 1}/{n_attempts} ...")
    logger.error(traceback.format_exc())
    await asyncio.sleep(2.0**attempt)
```

## Why It Works

The `break` statement was bypassing the retry mechanism entirely.
Replacing it with the same log-and-backoff pattern used by the
`HuggingFaceRateLimitError` handler means all 3 attempts are actually
used before giving up. The exponential backoff (1s, 2s) gives transient
issues time to resolve between attempts.

## Test Plan

### Manual Testing
- Downloads that hit transient network errors now retry instead of
failing immediately

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `uv run pytest src/exo/download/tests/ -v` — 11 tests pass

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-06 11:51:54 +00:00
Evan Quiney
c8371349d5 add scripts (#1401)
allow running exo-bench and the headless runner from nix
2026-02-06 11:06:40 +00:00
Evan Quiney
6b907398a4 cancel downloads for deleted instances (#1393)
after deleting an instance, if a given (node_id, model_id) pair doesn't exist in the remaining instances, cancel the download of model_id on node_id.
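
a minimal sketch of that reconciliation (illustrative names, not the
actual implementation):

```python
def downloads_to_cancel(
    active: set[tuple[str, str]],  # in-flight (node_id, model_id) downloads
    instances: list[dict],  # each with "model_id" and "node_ids"
) -> set[tuple[str, str]]:
    """Return the downloads that no remaining instance still needs."""
    needed = {
        (node_id, inst["model_id"])
        for inst in instances
        for node_id in inst["node_ids"]
    }
    return active - needed
```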
2026-02-05 18:16:43 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running, leading to bad shutdown states

## changes
- added `try/except` blocks around most task groups (sketch below)
- made the runner shutdown code synchronous
- abandon the `MpReceiver`'s `recv_async` thread on cancellation
  - this only occurs during runner shutdown; the queue closing from the
    other end should terminate the `mp.Queue`, cleaning up the thread in
    its own time. I could try other methods if this is not sufficient.
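
a minimal sketch of the first change (assuming `asyncio.TaskGroup`,
Python 3.11+; the real runner code differs):

```python
import asyncio

def shutdown_runner() -> None:
    # Synchronous shutdown: once started, cancellation can't interrupt it.
    print("runner cleaned up")

async def main() -> None:
    try:
        async with asyncio.TaskGroup() as tg:
            tg.create_task(asyncio.sleep(3600))
            raise RuntimeError("simulated fatal error / ctrl-c")
    except RuntimeError:
        pass
    finally:
        # Without the try/finally, the exception would skip this cleanup.
        shutdown_runner()

asyncio.run(main())
```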

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
72 changed files with 4609 additions and 1771 deletions

.gitignore (vendored): 3 changes

@@ -32,3 +32,6 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp
# bench files
bench/**/*.json


@@ -1139,7 +1139,7 @@ class array:
) -> array:
"""See :func:`flatten`."""
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`reshape` but the shape can be passed either as a
:obj:`tuple` or as separate arguments.
@@ -1222,7 +1222,7 @@ class array:
) -> array:
"""See :func:`swapaxes`."""
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
"""
Equivalent to :func:`transpose` but the axes can be passed either as
a tuple or as separate arguments.


@@ -30,6 +30,9 @@ class Conv1d(Module):
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""
weight: mx.array
groups: int
def __init__(
self,
in_channels: int,


@@ -11,7 +11,10 @@ import mlx.core as mx
class Cache(Protocol):
keys: mx.array
values: mx.array
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
offset: int
def update_and_fetch(
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter
@@ -87,6 +90,7 @@ def create_attention_mask(
class _BaseCache(Cache):
keys: mx.array
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
@state.setter


@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.mla import MultiLinear
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
# kv_b_proj: nn.Linear
embed_q: MultiLinear
unembed_out: MultiLinear
o_proj: nn.Linear
rope: Any


@@ -0,0 +1,114 @@
"""Type stubs for mlx_lm.models.qwen3_next"""
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .switch_layers import SwitchGLU
class Qwen3NextMLP(nn.Module):
gate_proj: nn.Linear
down_proj: nn.Linear
up_proj: nn.Linear
def __init__(self, dim: int, hidden_dim: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Qwen3NextGatedDeltaNet(nn.Module):
hidden_size: int
num_v_heads: int
num_k_heads: int
head_k_dim: int
head_v_dim: int
key_dim: int
value_dim: int
conv_kernel_size: int
conv_dim: int
conv1d: nn.Conv1d
in_proj_qkvz: nn.Linear
in_proj_ba: nn.Linear
dt_bias: mx.array
A_log: mx.array
out_proj: nn.Linear
def __init__(self, config: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextAttention(nn.Module):
num_attention_heads: int
num_key_value_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextSparseMoeBlock(nn.Module):
norm_topk_prob: bool
num_experts: int
top_k: int
gate: nn.Linear
switch_mlp: SwitchGLU
shared_expert: Qwen3NextMLP
shared_expert_gate: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Qwen3NextDecoderLayer(nn.Module):
is_linear: bool
linear_attn: Qwen3NextGatedDeltaNet
self_attn: Qwen3NextAttention
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
def __init__(self, args: Any, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Qwen3NextModel(nn.Module):
embed_tokens: nn.Embedding
layers: list[Qwen3NextDecoderLayer]
norm: nn.RMSNorm
def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: Qwen3NextModel
lm_head: nn.Linear
def __init__(self, args: Any) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[Qwen3NextDecoderLayer]: ...


@@ -113,6 +113,10 @@ class TokenizerWrapper:
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
think_start: str | None
think_end: str | None
think_start_id: int | None
think_end_id: int | None
def __init__(
self,


@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).


@@ -431,7 +431,12 @@ def main() -> int:
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Pipeline jaccl is often pointless, skip by default",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
@@ -450,6 +455,7 @@ def main() -> int:
default="bench/results.json",
help="Write raw per-run results JSON to this path.",
)
ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
@@ -533,6 +539,16 @@ def main() -> int:
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
@@ -652,7 +668,9 @@ def main() -> int:
time.sleep(5)
if args.json_out:
if args.stdout:
json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
elif args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump(all_rows, f, indent=2, ensure_ascii=False)
logger.debug(f"\nWrote results JSON: {args.json_out}")


@@ -1,7 +1,6 @@
<script lang="ts">
import {
isLoading,
stopGeneration,
sendMessage,
generateImage,
editImage,
@@ -606,92 +605,86 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
aria-label="Stop generation"
>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<svg
class="w-2.5 h-2.5 sm:w-3 sm:h-3"
viewBox="0 0 24 24"
fill="currentColor"
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<rect x="4" y="4" width="16" height="16" rx="2" />
</svg>
<span class="hidden sm:inline">STOP</span>
<span class="sm:hidden">...</span>
</span>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>
<!-- Bottom accent line -->


@@ -470,7 +470,6 @@ class AppStore {
messages = $state<Message[]>([]);
currentResponse = $state("");
isLoading = $state(false);
private currentAbortController: AbortController | null = null;
// Performance metrics
ttftMs = $state<number | null>(null); // Time to first token in ms
@@ -1739,11 +1738,9 @@ class AppStore {
return;
}
this.currentAbortController = new AbortController();
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
signal: this.currentAbortController.signal,
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
@@ -1857,7 +1854,6 @@ class AppStore {
"Unknown error",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -1989,10 +1985,6 @@ class AppStore {
assistantMessageId: string,
errorPrefix = "Failed to get response",
): void {
// Don't show error for user-initiated abort (stop button)
if (error instanceof DOMException && error.name === "AbortError") {
return;
}
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(
targetConversationId,
@@ -2034,17 +2026,6 @@ class AppStore {
return null;
}
/**
* Stop the current generation by aborting the HTTP connection.
* This triggers backend cancellation via the mechanism in PR #1276.
*/
stopGeneration() {
if (this.currentAbortController) {
this.currentAbortController.abort();
this.currentAbortController = null;
}
}
/**
* Send a message to the LLM and stream the response
*/
@@ -2192,13 +2173,11 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
this.currentAbortController = new AbortController();
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
signal: this.currentAbortController.signal,
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
@@ -2346,7 +2325,6 @@ class AppStore {
"Failed to get response",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2446,13 +2424,11 @@ class AppStore {
};
}
this.currentAbortController = new AbortController();
const response = await fetch("/v1/images/generations", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
signal: this.currentAbortController.signal,
body: JSON.stringify(requestBody),
});
@@ -2601,7 +2577,6 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2730,10 +2705,8 @@ class AppStore {
);
}
this.currentAbortController = new AbortController();
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
signal: this.currentAbortController.signal,
body: formData,
});
@@ -2843,7 +2816,6 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2972,7 +2944,6 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const stopGeneration = () => appStore.stopGeneration();
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;


@@ -83,6 +83,9 @@
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
overlays = [
(final: _: { apple-sdk_26 = final.callPackage ./nix/apple-sdk/package.nix { darwinSdkMajorVersion = "26"; }; })
];
};
treefmt = {
projectRootFile = "flake.nix";
@@ -105,7 +108,10 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
shfmt = {
enable = true;
excludes = [ "nix/apple-sdk/**" ];
};
};
};
@@ -118,9 +124,15 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit (self'.packages) metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
sdk-version = pkgs.runCommand "sdk-version" { } ''
mkdir -p $out
echo ${pkgs.apple-sdk_26.version} > $out/version
'';
}
);


@@ -20,7 +20,7 @@ sync-clean:
rust-rebuild:
cargo run --bin stub_gen
just sync-clean
uv sync --reinstall-package exo_pyo3_bindings
build-dashboard:
#!/usr/bin/env bash

nix/apple-sdk/README.md (new, empty file)


@@ -0,0 +1,48 @@
{ lib
, fetchFromGitHub
, stdenvNoCC
,
}:
let
CoreSymbolication = stdenvNoCC.mkDerivation (finalAttrs: {
pname = "CoreSymbolication";
version = "0-unstable-2018-06-17";
src = fetchFromGitHub {
repo = "CoreSymbolication";
owner = "matthewbauer";
rev = "24c87c23664b3ee05dc7a5a87d647ae476a680e4";
hash = "sha256-PzvLq94eNhP0+rLwGMKcMzxuD6MlrNI7iT/eV0obtSE=";
};
patches = [
# Add missing symbol definitions needed to build `zlog` in system_cmds.
# https://github.com/matthewbauer/CoreSymbolication/pull/2
../patches/0001-Add-function-definitions-needed-to-build-zlog-in-sys.patch
../patches/0002-Add-CF_EXPORT-To-const-symbols.patch
];
dontBuild = true;
installPhase = ''
mkdir -p "$out/include"
cp *.h "$out/include"
'';
meta = {
description = "Reverse engineered headers for Apple's CoreSymbolication framework";
homepage = "https://github.com/matthewbauer/CoreSymbolication";
license = lib.licenses.mit;
teams = [ lib.teams.darwin ];
platforms = lib.platforms.darwin;
};
});
in
self: super: {
buildPhase = super.buildPhase or "" + ''
mkdir -p System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
ln -s Versions/Current/Headers System/Library/PrivateFrameworks/CoreSymbolication.framework/Headers
cp '${CoreSymbolication}/include/'*.h System/Library/PrivateFrameworks/CoreSymbolication.framework/Versions/A/Headers
'';
}


@@ -0,0 +1,13 @@
{ lib, config }:
self: super: {
preBuild = super.preBuild or "" + ''
platformPath=$out/Platforms/MacOSX.platform
sdkpath=$platformPath/Developer/SDKs
'';
preInstall = super.preInstall or "" + ''
platformPath=$out/Platforms/MacOSX.platform
sdkpath=$platformPath/Developer/SDKs
'';
}


@@ -0,0 +1,38 @@
{ lib
, fetchurl
, cpio
, pbzx
,
}:
{ urls
, version
, hash
,
}:
fetchurl {
pname = "macOS-SDK";
inherit version urls hash;
recursiveHash = true;
nativeBuildInputs = [
cpio
pbzx
];
postFetch = ''
renamed=$(mktemp -d)/sdk.xar
mv "$downloadedFile" "$renamed"
pbzx "$renamed" | cpio -idm
src=Library/Developer/CommandLineTools/SDKs/MacOSX${lib.versions.majorMinor version}.sdk
# Remove unwanted binaries, man pages, and folders from the SDK.
rm -rf $src/usr/bin $src/usr/share $src/System/Library/Perl
mkdir -p "$out"
cp -rd $src/* "$out"
'';
}


@@ -0,0 +1,10 @@
{ makeSetupHook, sdkVersion }:
self: super: {
passthru = super.passthru or { } // {
privateFrameworksHook = makeSetupHook
{
name = "apple-sdk-private-frameworks-hook";
} ../setup-hooks/add-private-frameworks.sh;
};
}


@@ -0,0 +1,38 @@
let
lockfile = builtins.fromJSON (builtins.readFile ../metadata/apple-oss-lockfile.json);
in
{ lib
, fetchFromGitHub
, stdenvNoCC
, sdkVersion
,
}:
let
sdkinfo = lockfile.${sdkVersion};
in
self: super: {
passthru = super.passthru or { } // {
# Returns the raw source from apple-oss-distributions repo.
# This is mostly useful for copying private headers needed to build other source releases.
#
# Note: The source releases are mostly not used to build the SDK. Unless they can be used to build binaries,
# they're not used.
sourceRelease =
name:
let
lockinfo = sdkinfo.${name};
in
fetchFromGitHub
{
owner = "apple-oss-distributions";
repo = name;
rev = lockinfo.rev or "${name}-${lockinfo.version}";
inherit (lockinfo) hash;
}
// {
inherit (lockinfo) version;
};
};
}


@@ -0,0 +1,327 @@
{ lib
, stdenvNoCC
, xcodePlatform
, sdkVersion
,
}:
let
inherit (lib.generators) toPlist;
Info = rec {
CFBundleIdentifier = "com.apple.platform.${Name}";
DefaultProperties = {
COMPRESS_PNG_FILES = "NO";
DEPLOYMENT_TARGET_SETTING_NAME = stdenvNoCC.hostPlatform.darwinMinVersionVariable;
STRIP_PNG_TEXT = "NO";
};
Description = if stdenvNoCC.hostPlatform.isMacOS then "macOS" else "iOS";
FamilyIdentifier = lib.toLower xcodePlatform;
FamilyName = Description;
Identifier = CFBundleIdentifier;
MinimumSDKVersion = stdenvNoCC.hostPlatform.darwinMinVersion;
Name = lib.toLower xcodePlatform;
Type = "Platform";
Version = sdkVersion;
};
# These files are all based off of Xcode spec files found in
# /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/Library/Xcode/PrivatePlugIns/IDEOSXSupportCore.ideplugin/Contents/Resources.
# Based off of the "MacOSX Architectures.xcspec" file. All i386 stuff
# is removed because NixPkgs only supports darwin-x86_64 and darwin-arm64.
Architectures = [
{
Identifier = "Standard";
Type = "Architecture";
Name = "Standard Architectures (Apple Silicon, 64-bit Intel)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD";
}
{
Identifier = "Universal";
Type = "Architecture";
Name = "Universal (Apple Silicon, 64-bit Intel)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_32_64_BIT";
}
{
Identifier = "Native";
Type = "Architecture";
Name = "Native Architecture of Build Machine";
ArchitectureSetting = "NATIVE_ARCH_ACTUAL";
}
{
Identifier = "Standard64bit";
Type = "Architecture";
Name = "Apple Silicon, 64-bit Intel";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_64_BIT";
}
{
Identifier = stdenvNoCC.hostPlatform.darwinArch;
Type = "Architecture";
Name = "Apple Silicon or Intel 64-bit";
}
{
Identifier = "Standard_Including_64_bit";
Type = "Architecture";
Name = "Standard Architectures (including 64-bit)";
RealArchitectures = [
"arm64"
"x86_64"
];
ArchitectureSetting = "ARCHS_STANDARD_INCLUDING_64_BIT";
}
];
# Based off of the "MacOSX Package Types.xcspec" file. Only keep the
# bare minimum needed.
PackageTypes = [
{
Identifier = "com.apple.package-type.mach-o-executable";
Type = "PackageType";
Name = "Mach-O Executable";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.executable";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.mach-o-objfile";
Type = "PackageType";
Name = "Mach-O Object File";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.objfile";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.mach-o-dylib";
Type = "PackageType";
Name = "Mach-O Dynamic Library";
DefaultBuildSettings = {
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "compiled.mach-o.dylib";
Name = "$(EXECUTABLE_NAME)";
};
}
{
Identifier = "com.apple.package-type.static-library";
Type = "PackageType";
Name = "Mach-O Static Library";
DefaultBuildSettings = {
EXECUTABLE_PREFIX = "lib";
EXECUTABLE_SUFFIX = ".a";
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_PATH = "$(EXECUTABLE_NAME)";
};
ProductReference = {
FileType = "archive.ar";
Name = "$(EXECUTABLE_NAME)";
IsLaunchable = "NO";
};
}
{
Identifier = "com.apple.package-type.wrapper";
Type = "PackageType";
Name = "Wrapper";
DefaultBuildSettings = {
WRAPPER_SUFFIX = ".bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
CONTENTS_FOLDER_PATH = "$(WRAPPER_NAME)/Contents";
EXECUTABLE_NAME = "$(EXECUTABLE_PREFIX)$(PRODUCT_NAME)$(EXECUTABLE_VARIANT_SUFFIX)$(EXECUTABLE_SUFFIX)";
EXECUTABLE_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/MacOS";
EXECUTABLE_PATH = "$(EXECUTABLE_FOLDER_PATH)/$(EXECUTABLE_NAME)";
INFOPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/Info.plist";
INFOSTRINGS_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/InfoPlist.strings";
PKGINFO_PATH = "$(CONTENTS_FOLDER_PATH)/PkgInfo";
PBDEVELOPMENTPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/pbdevelopment.plist";
VERSIONPLIST_PATH = "$(CONTENTS_FOLDER_PATH)/version.plist";
PUBLIC_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Headers";
PRIVATE_HEADERS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PrivateHeaders";
EXECUTABLES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Executables";
FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Frameworks";
SHARED_FRAMEWORKS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedFrameworks";
SHARED_SUPPORT_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/SharedSupport";
UNLOCALIZED_RESOURCES_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/Resources";
LOCALIZED_RESOURCES_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/$(DEVELOPMENT_LANGUAGE).lproj";
DOCUMENTATION_FOLDER_PATH = "$(LOCALIZED_RESOURCES_FOLDER_PATH)/Documentation";
PLUGINS_FOLDER_PATH = "$(CONTENTS_FOLDER_PATH)/PlugIns";
SCRIPTS_FOLDER_PATH = "$(UNLOCALIZED_RESOURCES_FOLDER_PATH)/Scripts";
};
ProductReference = {
FileType = "wrapper.cfbundle";
Name = "$(WRAPPER_NAME)";
IsLaunchable = "NO";
};
}
{
Identifier = "com.apple.package-type.wrapper.application";
Type = "PackageType";
BasedOn = "com.apple.package-type.wrapper";
Name = "Application Wrapper";
DefaultBuildSettings = {
GENERATE_PKGINFO_FILE = "YES";
};
ProductReference = {
FileType = "wrapper.application";
Name = "$(WRAPPER_NAME)";
IsLaunchable = "YES";
};
}
];
# Based off of the "MacOSX Product Types.xcspec" file. All
# bundles/wrapper are removed, because we prefer dynamic products in
# NixPkgs.
ProductTypes = [
{
Identifier = "com.apple.product-type.tool";
Type = "ProductType";
Name = "Command-line Tool";
PackageTypes = [ "com.apple.package-type.mach-o-executable" ];
}
{
Identifier = "com.apple.product-type.objfile";
Type = "ProductType";
Name = "Object File";
PackageTypes = [ "com.apple.package-type.mach-o-objfile" ];
}
{
Identifier = "com.apple.product-type.library.dynamic";
Type = "ProductType";
Name = "Dynamic Library";
PackageTypes = [ "com.apple.package-type.mach-o-dylib" ];
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
MACH_O_TYPE = "mh_dylib";
REZ_EXECUTABLE = "YES";
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
EXECUTABLE_EXTENSION = "dylib";
DYLIB_COMPATIBILITY_VERSION = "1";
DYLIB_CURRENT_VERSION = "1";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "debugging";
GCC_INLINES_ARE_PRIVATE_EXTERN = "YES";
CODE_SIGNING_ALLOWED = "YES";
CODE_SIGNING_REQUIRED = "NO";
};
}
{
Identifier = "com.apple.product-type.library.static";
Type = "ProductType";
Name = "Static Library";
PackageTypes = [ "com.apple.package-type.static-library" ];
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(EXECUTABLE_NAME)";
MACH_O_TYPE = "staticlib";
REZ_EXECUTABLE = "YES";
EXECUTABLE_PREFIX = "lib";
EXECUTABLE_SUFFIX = ".$(EXECUTABLE_EXTENSION)";
EXECUTABLE_EXTENSION = "a";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "debugging";
SEPARATE_STRIP = "YES";
CLANG_ENABLE_MODULE_DEBUGGING = "NO";
};
}
{
Type = "ProductType";
Identifier = "com.apple.product-type.bundle";
Name = "Bundle";
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
MACH_O_TYPE = "mh_bundle";
WRAPPER_PREFIX = "";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "non-global";
};
PackageTypes = [ "com.apple.package-type.wrapper" ];
IsWrapper = "YES";
HasInfoPlist = "YES";
HasInfoPlistStrings = "YES";
}
{
Identifier = "com.apple.product-type.application";
Type = "ProductType";
BasedOn = "com.apple.product-type.bundle";
Name = "Application";
DefaultBuildProperties = {
MACH_O_TYPE = "mh_execute";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "app";
};
PackageTypes = [ "com.apple.package-type.wrapper.application" ];
}
{
Type = "ProductType";
Identifier = "com.apple.product-type.framework";
Name = "Bundle";
DefaultBuildProperties = {
FULL_PRODUCT_NAME = "$(WRAPPER_NAME)";
MACH_O_TYPE = "mh_bundle";
WRAPPER_PREFIX = "";
WRAPPER_SUFFIX = ".$(WRAPPER_EXTENSION)";
WRAPPER_EXTENSION = "bundle";
WRAPPER_NAME = "$(WRAPPER_PREFIX)$(PRODUCT_NAME)$(WRAPPER_SUFFIX)";
FRAMEWORK_FLAG_PREFIX = "-framework";
LIBRARY_FLAG_PREFIX = "-l";
LIBRARY_FLAG_NOSPACE = "YES";
STRIP_STYLE = "non-global";
};
PackageTypes = [ "com.apple.package-type.wrapper" ];
IsWrapper = "YES";
HasInfoPlist = "YES";
HasInfoPlistStrings = "YES";
}
];
ToolchainInfo = {
Identifier = "com.apple.dt.toolchain.XcodeDefault";
};
in
{
"Info.plist" = builtins.toFile "Info.plist" (toPlist { escape = true; } Info);
"ToolchainInfo.plist" = builtins.toFile "ToolchainInfo.plist" (
toPlist { escape = true; } ToolchainInfo
);
"Architectures.xcspec" = builtins.toFile "Architectures.xcspec" (
toPlist { escape = true; } Architectures
);
"PackageTypes.xcspec" = builtins.toFile "PackageTypes.xcspec" (
toPlist { escape = true; } PackageTypes
);
"ProductTypes.xcspec" = builtins.toFile "ProductTypes.xcspec" (
toPlist { escape = true; } ProductTypes
);
}


@@ -0,0 +1,40 @@
let
removedDylibs = [
# corecrypto is available under a very restrictive license (effectively: non-free, can't use).
# Without the headers and not being able to use corecrypto due to its license, it's not very useful.
# Stubs are included in the SDK for all dylibs, including corecrypto. They should be removed.
"/usr/lib/system/libcorecrypto.dylib"
];
in
{ lib
, jq
, llvm
,
}:
self: super: {
nativeBuildInputs = super.nativeBuildInputs or [ ] ++ [
jq
llvm
];
buildPhase = super.buildPhase or "" + ''
echo "Removing the following dylibs from the libSystem reexported libraries list: ${lib.escapeShellArg (lib.concatStringsSep ", " removedDylibs)}"
for libSystem in libSystem.B.tbd libSystem.B_asan.tbd; do
# tbd-v5 is a JSON-based format, which can be manipulated by `jq`.
llvm-readtapi --filetype=tbd-v5 usr/lib/$libSystem \
| jq --argjson libs ${lib.escapeShellArg (builtins.toJSON removedDylibs)} '
if .libraries then
.libraries[] |= select(.install_names[] | any([.] | inside($libs)) | not)
else
.
end
| .main_library.reexported_libraries[].names[] |= select([.] | inside($libs) | not)
' > usr/lib/$libSystem~
# Convert libSystem back to tbd-v4 because not all tooling supports the JSON-based format yet.
llvm-readtapi --filetype=tbd-v4 usr/lib/$libSystem~ -o usr/lib/$libSystem
rm usr/lib/$libSystem~
done
'';
}


@@ -0,0 +1,74 @@
{ lib
, cups
, darwin
, db
, libiconv
, ncurses
, stdenv
, stdenvNoCC
, xcbuild
,
}:
let
# CUPS has too many dependencies to build as part of the Darwin bootstrap. It's also typically taken as an explicit
# dependency by other packages, so building only the headers (to satisfy other SDK headers) should be okay.
cupsHeaders = darwin.bootstrapStdenv.mkDerivation {
pname = "${lib.getName cups}-headers";
version = lib.getVersion cups;
inherit (cups) src;
patches = cups.patches or [ ];
strictDeps = true;
dontBuild = true;
buildInputs = [ darwin.libresolv ]; # The `configure` script requires libresolv headers.
# CUPS's configure script fails to find `ar` when cross-compiling.
configureFlags = [ "ac_cv_path_AR=${stdenv.cc.targetPrefix}ar" ];
installTargets = [ "install-headers" ];
__structuredAttrs = true;
meta = {
inherit (cups.meta)
homepage
description
license
maintainers
platforms
;
};
};
in
self: super: {
# These packages are propagated only because other platforms include them in their libc (or otherwise by default).
# Reducing the number of special cases required to support Darwin makes supporting it easier for package authors.
propagatedBuildInputs =
super.propagatedBuildInputs or [ ]
++ [
libiconv
darwin.libresolv
darwin.libsbuf
# Shipped with the SDK only as a library with no headers
(lib.getLib darwin.libutil)
]
# x86_64-darwin links the object files from Csu when targeting very old releases
++ lib.optionals stdenvNoCC.hostPlatform.isx86_64 [ darwin.Csu ];
# The Darwin module for Swift requires certain headers to be included in the SDK (and not just be propagated).
buildPhase = super.buildPhase or "" + ''
for header in '${lib.getDev libiconv}/include/'* '${lib.getDev ncurses}/include/'* '${cupsHeaders}/include/'*; do
ln -s "$header" "usr/include/$(basename "$header")"
done
'';
# Exported to allow the headers to pass the requisites check in the stdenv bootstrap.
passthru = (super.passthru or { }) // {
cups-headers = cupsHeaders;
};
}


@@ -0,0 +1,53 @@
{ lib
, pkgsBuildHost
, stdenv
, stdenvNoCC
, sdkVersion
,
}:
let
plists = import ./plists.nix {
inherit lib stdenvNoCC sdkVersion;
xcodePlatform = if stdenvNoCC.hostPlatform.isMacOS then "MacOSX" else "iPhoneOS";
};
inherit (pkgsBuildHost) darwin cctools xcbuild;
in
self: super: {
propagatedNativeBuildInputs = super.propagatedNativeBuildInputs or [ ] ++ [ xcbuild.xcrun ];
postInstall = super.postInstall or "" + ''
specspath=$out/Library/Xcode/Specifications
toolchainsPath=$out/Toolchains/XcodeDefault.xctoolchain
mkdir -p "$specspath" "$toolchainsPath"
# xcbuild expects to find things relative to the plist locations. If these are linked instead of copied,
# it won't find any platforms or SDKs.
cp '${plists."Info.plist"}' "$platformPath/Info.plist"
cp '${plists."ToolchainInfo.plist"}' "$toolchainsPath/ToolchainInfo.plist"
for spec in '${xcbuild}/Library/Xcode/Specifications/'*; do
ln -s "$spec" "$specspath/$(basename "$spec")"
done
cp '${plists."Architectures.xcspec"}' "$specspath/Architectures.xcspec"
cp '${plists."PackageTypes.xcspec"}' "$specspath/PackageTypes.xcspec"
cp '${plists."ProductTypes.xcspec"}' "$specspath/ProductTypes.xcspec"
mkdir -p "$out/usr/bin"
ln -s '${xcbuild.xcrun}/bin/xcrun' "$out/usr/bin/xcrun"
# Include `libtool` in the toolchain, so `xcrun -find libtool` can find it without requiring `cctools.libtool`
# as a `nativeBuildInput`.
mkdir -p "$toolchainsPath/usr/bin"
if [ -e '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' ]; then
ln -s '${cctools.libtool}/bin/${stdenv.cc.targetPrefix}libtool' "$toolchainsPath/usr/bin/libtool"
fi
# Include additional binutils required by some packages (such as Chromium).
for tool in lipo nm otool size strip; do
if [ -e '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool ]; then
ln -s '${darwin.binutils-unwrapped}/bin/${stdenv.cc.targetPrefix}'$tool "$toolchainsPath/usr/bin/$tool"
fi
done
'';
}


@@ -0,0 +1,24 @@
let
disallowedPackages = builtins.fromJSON (builtins.readFile ../metadata/disallowed-packages.json);
in
{ lib
, jq
, stdenv
,
}:
self: super: {
# Remove headers and stubs for packages that are available in nixpkgs.
buildPhase = super.buildPhase or "" + ''
${lib.concatMapStringsSep "\n" (
pkg:
lib.concatLines (
[ ''echo "Removing headers and libraries from ${pkg.package}"'' ]
++ (map (header: "rm -rf -- usr/include/${header}") pkg.headers or [ ])
++ (map (framework: "rm -rf -- System/Library/Frameworks/${framework}") pkg.frameworks or [ ])
++ (map (library: "rm -rf -- usr/lib/${library}") pkg.libraries or [ ])
)
) disallowedPackages}
'';
}


@@ -0,0 +1,9 @@
{}:
self: super: {
buildPhase = ''
runHook preBuild
${super.buildPhase or ""}
runHook postBuild
'';
}


@@ -0,0 +1,536 @@
{
"14.4": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-/VoOR9wJuKnmGE1CWGGXxX8SpmALHnEooNTa3QM+ITc=",
"version": "600028.100.1"
},
"IOAudioFamily": {
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
"version": "540.3"
},
"IOBDStorageFamily": {
"hash": "sha256-UgLMsQBe1QLzlbScmPmASBN7VH4YBmNOUX2CEDezjmE=",
"version": "22"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "61"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "45"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-IUytBKhhCgg0vtI+7q8d5kxpOUgO3tQD7TMy++jrorc=",
"version": "431"
},
"IOFireWireFamily": {
"hash": "sha256-W0KOF4hkA7kFOnL1ThAeFU/YlhFVqoqk9uzGjcBppX8=",
"version": "487"
},
"IOFireWireSBP2": {
"hash": "sha256-bItnRQIaGUxMyiU0q+4N8e5+jYiDEOUPmsrKhBFXvok=",
"version": "445"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
"version": "260"
},
"IOGraphics": {
"hash": "sha256-Ag37fd3tZJLXLVq1yzHOCWGOYYfwwTkC8hnvNaTEaWg=",
"version": "598"
},
"IOHIDFamily": {
"hash": "sha256-fmYTJsquAOBwzsgRmqPyjSJJi1hGcfnMmqLIcTe8W1s=",
"version": "2031.100.16"
},
"IOKitUser": {
"hash": "sha256-1bqRiLvyr2GQfbWwhXHXXIOtIka9YDw5GbKV6bd2k4k=",
"version": "100076.101.1"
},
"IONetworkingFamily": {
"hash": "sha256-J3cLeWKrQ8ypIaqgwRH9eU5JbjEDBVoezj3a2Lvwu5k=",
"version": "177"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-cllpJX11c3CX8zEYdOT2TC63sx7NUAHh33yRHhrG2Ro=",
"version": "315"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-fxBM4KbPwQNVEJl7PCKP+1nUk9Oce/O2+0lVBxyngew=",
"version": "1592.100.35"
},
"Libinfo": {
"hash": "sha256-zZr6Mmou8Q+G6/wS+k0k7R+XirB94TNCUGS5dhi96ZE=",
"version": "583.0.1"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-7X+6S3C7ZOTXJUeDXOOg5EmoZyLZvtE06x3Is0TGgSU=",
"version": "317.100.2"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-HsItciWrwyXujQ2hwqzv0JKOkkuynXYIqejLAEPJbMc=",
"version": "1345.100.2"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-NgTGbaw5JkpboDQpt1fSgUr9NYGS+bIOrEMQX7mLAME=",
"version": "61123.100.169"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-+3xesYxqfsNjWCW3T87OA7+Z1hBqmGEh/I8kP8Ajbso=",
"version": "1300.100.9"
},
"copyfile": {
"hash": "sha256-rSCTgzdHr7QmnPk9rJ9P4fOAolnEQv8PHfgAY+qA0s4=",
"version": "196.100.4"
},
"dtrace": {
"hash": "sha256-04Q35rCKnM5Csv5poFJKpK0VplWq4hvy251/Cb2Kl80=",
"version": "401.100.3"
},
"dyld": {
"hash": "sha256-6P/Da6xP19vmaCROoYv9pl7DaW3/U+qZBJT8PD33bn0=",
"version": "1160.6"
},
"eap8021x": {
"hash": "sha256-Ky6KSlJhyX1NRufGhVBcp+ZFmqYrAxwC/5QvJhC2PhU=",
"version": "354.100.3"
},
"hfs": {
"hash": "sha256-+YUVOttZU7C8I14CC6t3ZH2KxAjjTA2nB0y5bPgLxZM=",
"version": "650.0.2"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-M/jnIHzKYvdFCO0tJ1JXiD/UcZtJhLIoulaCQQUbn30=",
"version": "90"
},
"libdispatch": {
"hash": "sha256-igqIA5DMVHjG30WMHZZpYY7LRM9hZyMWItD+UxeTehY=",
"version": "1477.100.9"
},
"libmalloc": {
"hash": "sha256-Sh4/z7lGWRMldOPURkP5vLOAb5Ou6AUsVJEWz9wk9hI=",
"version": "521.100.59"
},
"libplatform": {
"hash": "sha256-gojt3sWOr7XO2yYI/B1CmNLTPFieSfoNtlOgQahOCok=",
"version": "316.100.10"
},
"libpthread": {
"hash": "sha256-phjfN8+IU8ibPsflR6LktnSi3giy89ghI+cFyrhiQNo=",
"version": "519.101.1"
},
"mDNSResponder": {
"hash": "sha256-0ECbWeMnIRTsi03BeBEe5boyR/84JJPbxzPQze8hHSA=",
"version": "2200.100.94.0.2"
},
"objc4": {
"hash": "sha256-eUVSpbyTEOMEdHoxSv6lZIZwB+cW/YWIaTZTcHgGOjo=",
"version": "912.3"
},
"ppp": {
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
"version": "1016"
},
"removefile": {
"hash": "sha256-L6I0u8S3h3uV1veKA5HvkSebbBCd78ymlf//KWbebZo=",
"version": "70.100.4"
},
"xnu": {
"hash": "sha256-j5Ep1RX5DTJqTGszrF4d/JtzUqZ6nA6XoExqcIQ0RVQ=",
"version": "10063.101.15"
}
},
"15.5": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
"version": "600035"
},
"IOAudioFamily": {
"hash": "sha256-VSk3jvsITJugtL67Qt0m4qJ879i7Fj6B/NGBFVCwpiU=",
"version": "600.2"
},
"IOBDStorageFamily": {
"hash": "sha256-s8hTwX0jq2iPULfBLUwpzqtszWuvJrrLGbmrKa/fY4U=",
"version": "24"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "62"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "46"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
"version": "434"
},
"IOFireWireFamily": {
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
"version": "490"
},
"IOFireWireSBP2": {
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
"version": "452"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-P7egeaD9SSa+YyrIRzM44gILKbIL7vezXK3M6q3MBOI=",
"version": "261"
},
"IOGraphics": {
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
"version": "599"
},
"IOHIDFamily": {
"hash": "sha256-gEYPyjXgQ2ABGufCKPjmzMdNRLxhELkCvOURCokyTO4=",
"version": "2115.100.21"
},
"IOKitUser": {
"hash": "sha256-p32U+jHfwA/tqnjF4p1BmojghEXK8KxiflW3IHs2iIY=",
"version": "100150.120.2"
},
"IONetworkingFamily": {
"hash": "sha256-gZ7Dkk4Iu7AV9K2ioqSeJ1W7bTNxv77bmT18iv3ljLg=",
"version": "185"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-/0H0tqWUWkgYigYypucbc7lOCFYDuukwF9fvLEOhwOk=",
"version": "323"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-nWDokN0Vr5pUyNGculnDOah9RNgHiWr3S13RSQLmZrc=",
"version": "1698.100.8"
},
"Libinfo": {
"hash": "sha256-UI5mGvzZ6BPafGYD6CrNAJAKjeJLB6urAS2lpB6X/Ec=",
"version": "597"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-GDYMVi1034f9empq0YOuumQp/BDJ7phTb0Zl4KTY9xg=",
"version": "342"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-nawWJiu2IJ34ek5iOX6CrlqMzev7TuJpUkvDp30ZQ/U=",
"version": "1351"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-ZOrOOCk+hZbzDilzkihpQfsDpzV3Ul4zy6fpFRWUQHw=",
"version": "61439.120.27"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-ZdUq1SrOwB88Lx68ekrA4zeVsLDZz4TAJywNnF+uAzY=",
"version": "1351.120.3"
},
"copyfile": {
"hash": "sha256-rLqT6e44W2ohgwUXREmiOyJBYCrV3gRLbtVnbUq60xc=",
"version": "221.121.1"
},
"dtrace": {
"hash": "sha256-iNEZyxK3DmEwO3gzrfvCaVZSEuuOMQm5IG/6FodPNdI=",
"version": "411"
},
"dyld": {
"hash": "sha256-4OOghgUYyMJbsTe96fiWCndTJ1BS94rK9v6Kqn/ooYs=",
"version": "1285.19"
},
"eap8021x": {
"hash": "sha256-Kx/wwnt108hDm0qQPyTNbZ8KoHkD5m7L4yb5qjSuQjI=",
"version": "365.120.2"
},
"hfs": {
"hash": "sha256-5/3Ycp3cKqlgAl1kjBmbF5tFlfJYQS5rbrbk4SS66b8=",
"version": "683.120.3"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
"version": "96"
},
"libdispatch": {
"hash": "sha256-jTp2DolOOCQPBt1HRotkmPnKgQ2LGgniEqeHoM+vlKg=",
"version": "1521.120.4"
},
"libmalloc": {
"hash": "sha256-d9AVHSYTqHDlgctv8Hh4HAYW53MJelj4F8LWPsjrsws=",
"version": "715.120.13"
},
"libplatform": {
"hash": "sha256-gpijoTMvdkM0PdG8gyIllOJlh/MtTc4ro9ODDAhN6gM=",
"version": "349"
},
"libpthread": {
"hash": "sha256-N+MMXdbthsxauTTfZ5ElUs39dVH+Chn1yyU6pObZpkU=",
"version": "536"
},
"mDNSResponder": {
"hash": "sha256-ILx12PRxj/+VqfpCCErJFEJXFI9yzTh4g+FK0UCenIE=",
"version": "2600.120.12"
},
"objc4": {
"hash": "sha256-DMxa25gXjKCkiDnVJ/8SyJUjaBlmBGABg8EfCHcmTj0=",
"version": "940.4"
},
"ppp": {
"hash": "sha256-8+QUA79sHf85yvGSPE9qCmGsrZDT3NZnbgZVroJw/Hg=",
"version": "1016"
},
"removefile": {
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
"version": "81"
},
"xnu": {
"hash": "sha256-o4tCuCAIgAYg/Li3wTs12mVWr5C/4vbwu1zi+kJ9d6w=",
"version": "11417.121.6"
}
},
"26.0": {
"CarbonHeaders": {
"hash": "sha256-nIPXnLr21yVnpBhx9K5q3l/nPARA6JL/dED08MeyhP8=",
"version": "18.1"
},
"CommonCrypto": {
"hash": "sha256-+qAwL6+s7di9cX/qXtapLkjCFoDuZaSYltRJEG4qekM=",
"version": "600035"
},
"IOAudioFamily": {
"hash": "sha256-A3iiAjjP29VdjMj40tLS5Q/ni4qeh9bBpnmNzeG2pIY=",
"version": "700.2"
},
"IOBDStorageFamily": {
"hash": "sha256-OcQUJ3nEfrpvWX/npnedJ4PECIGWFSLiM0PKoiH911w=",
"version": "26"
},
"IOCDStorageFamily": {
"hash": "sha256-p/2qM5zjXFDRb/DISpEHxQEdvmuLlRGt/Ygc71Yu2rI=",
"version": "62"
},
"IODVDStorageFamily": {
"hash": "sha256-1Sa8aZBGNtqJBNHva+YXxET6Wcdm2PgVrTzYT/8qrN4=",
"version": "46"
},
"IOFWDVComponents": {
"hash": "sha256-WkfkWnzRupEh20U7vjsTta89clhus6GTkOpXQWXw/bM=",
"version": "208"
},
"IOFireWireAVC": {
"hash": "sha256-qR9lSTa7PN5Z9Nis4tfuXlcZGMIU48dete/NPD0UBbE=",
"version": "436"
},
"IOFireWireFamily": {
"hash": "sha256-hmErAXjLWIelqJaCrB8J4IiIxyB7S6EHFY+AY9YhmKQ=",
"version": "492"
},
"IOFireWireSBP2": {
"hash": "sha256-Xk+PDnUaO9q46nQwHwTKf/QXtGclfs0wTWiUbcV7e4s=",
"version": "454"
},
"IOFireWireSerialBusProtocolTransport": {
"hash": "sha256-cM/VFhVWNVwdJYk+mme0UYttQd7eJwd7Hlo7KNRyHY0=",
"version": "262"
},
"IOGraphics": {
"hash": "sha256-iysZE42mOKZbFxSZBNspaBTCRKEKK38DFGBxZWQxZxI=",
"version": "599"
},
"IOHIDFamily": {
"hash": "sha256-YLnabX90g4Q8LxjwVuJF6KODCDxychWV+VJaNG9d8fI=",
"version": "2222.0.24"
},
"IOKitUser": {
"hash": "sha256-ngwi8YMUqE0q8j7Lr5cqJwi2V+IDu3ie3bduotHIUJU=",
"version": "100222.0.4"
},
"IONetworkingFamily": {
"hash": "sha256-ZF5ML41Y1l1liQn32qTkcl4mMvx9Xdizb9VgvTzVTL4=",
"version": "186"
},
"IOSerialFamily": {
"hash": "sha256-wVS4QTx6MBOS0VrwyCZ3s5Usezwaf8rWzmNnfdDTXTU=",
"version": "93"
},
"IOStorageFamily": {
"hash": "sha256-1FKSF622qeXPGngA3UmQ2M/IU1pdlMoYBPbXytUFDaQ=",
"version": "331"
},
"IOUSBFamily": {
"hash": "sha256-Z0E3TfKP49toYo1Fo9kElRap8CZ+mVDHy5RIexgJTpA=",
"version": "630.4.5"
},
"Libc": {
"hash": "sha256-k+HQ+qgye0ORFm0hU8WzE4ysbbEoFZ7wcbVl5giDH/E=",
"version": "1725.0.11"
},
"Libinfo": {
"hash": "sha256-4InBEPi0n2EMo/8mIBib1Im4iTKRcRJ4IlAcLCigVGk=",
"version": "600"
},
"Libm": {
"hash": "sha256-p4BndAag9d0XSMYWQ+c4myGv5qXbKx5E1VghudSbpTk=",
"version": "2026"
},
"Libnotify": {
"hash": "sha256-p8cJZlBYOFmI1NDHXGYjgcv8z9Ldc1amZuYlxxJfeVY=",
"version": "344.0.1"
},
"Librpcsvc": {
"hash": "sha256-UWYdCQ9QsBqwM01bWr+igINAHSdSluB/FrOclC5AjTI=",
"version": "31"
},
"Libsystem": {
"hash": "sha256-/NlSwPaoTVx+bl9hYsfz3C5MuLdqGv4vdAh0KDbDKmY=",
"version": "1356"
},
"OpenDirectory": {
"hash": "sha256-6fSl8PasCZSBfe0ftaePcBuSEO3syb6kK+mfDI6iR7A=",
"version": "146"
},
"Security": {
"hash": "sha256-oxOvZsDoNYZNiWf+MASHrR4Q2o5oaqvK2We51hH7CO8=",
"version": "61901.0.87.0.1"
},
"architecture": {
"hash": "sha256-PRNUrhzSOrwmxSPkKmV0LV7yEIik65sdkfKdBqcwFhU=",
"version": "282"
},
"configd": {
"hash": "sha256-58or+OQP788UgQKO7Y8k8pY/enaSqH971ks7xCPu8fA=",
"version": "1385.0.7"
},
"copyfile": {
"hash": "sha256-I9uDi5BDQKa7mO3XpHxv0d6PiROW2ueZ3vGfrsG0OJo=",
"version": "230.0.1.0.1"
},
"dtrace": {
"hash": "sha256-5HpH6Cg8vWWzOX5ADD//izKDvqGnzV05Giju8lmGeyA=",
"version": "413"
},
"dyld": {
"hash": "sha256-jzoFLwbms0rUwzyjYif/r6Rmr4kyn+as/bhc4paEPeY=",
"version": "1323.3"
},
"eap8021x": {
"hash": "sha256-17bseWT4OWMA8hF+YSDDjxhVyJpbpP2xwv8dGti1YoM=",
"version": "368.0.3"
},
"hfs": {
"hash": "sha256-OkgqZ03gwn2hTuHxZrPDmQOrY4Dwu7MrX+BfG+PTgvE=",
"version": "704.0.3.0.2"
},
"launchd": {
"hash": "sha256-8mW9bnuHmRXCx9py8Wy28C5b2QPICW0rlAps5njYa00=",
"version": "842.1.4"
},
"libclosure": {
"hash": "sha256-pvwfcbeEJmTEPdt6/lgVswiabLRG+sMN6VT5FwG7C4Q=",
"version": "96"
},
"libdispatch": {
"hash": "sha256-L0+Ho9dAlMXVpqFEGIcIMsJc0gULckRulUImNEZe5MU=",
"version": "1542.0.4"
},
"libmalloc": {
"hash": "sha256-482hgm1ESr3LWC/JhuQNGNu9smsa2Eap49/eH+YNAio=",
"version": "792.1.1"
},
"libplatform": {
"hash": "sha256-wGZ2Im81mRXx6epgj/tbOJpg89CEbAr0Z8oFEpkyNMU=",
"version": "359.1.2"
},
"libpthread": {
"hash": "sha256-VuMpQjxuMsdHsFq0q6QIWSWi88gVF2jNzIfti20Gkbw=",
"version": "539"
},
"mDNSResponder": {
"hash": "sha256-iRqCpPAQDRjgRbRz3s6q2oyzq6xo+w4FTBai79104Zo=",
"version": "2881.0.25"
},
"objc4": {
"hash": "sha256-Nlgr36yLvGkUJIEFQ5w8FAB0r2syEsRTw0KuUShNT8E=",
"version": "950"
},
"ppp": {
"hash": "sha256-FzHZ05o7JxwgTqz0e3D68b/DiLu2x2ErzGMh0U78fLo=",
"version": "1020.1.1"
},
"removefile": {
"hash": "sha256-Z5UD0mk/s80CQB0PZWDzSl2JWXmnVmwUvlNb28+hR3k=",
"version": "84"
},
"xnu": {
"hash": "sha256-Cuf7kPtsn4CPXqyZmxVsJlA5i+Ikryp8ezJyGrvT63c=",
"version": "12377.1.9"
}
}
}

View File

@@ -0,0 +1,533 @@
[
{
"package": "apache",
"headers": [
"apache2"
]
},
{
"package": "apr",
"headers": [
"apr-1"
],
"libraries": [
"libapr-1.*",
"libaprutil-1.*"
]
},
{
"package": "boringssl",
"libraries": [
"libboringssl.*"
]
},
{
"package": "bzip2",
"headers": [
"bzlib.h"
],
"libraries": [
"libbz2.*"
]
},
{
"package": "corecrypto",
"libraries": [
"system/libcorecrypto*"
]
},
{
"package": "Csu",
"libraries": [
"*.o"
]
},
{
"package": "cups",
"headers": [
"cups"
],
"libraries": [
"libcups*"
]
},
{
"package": "curl",
"headers": [
"curl"
],
"libraries": [
"libcurl.*"
]
},
{
"package": "cyrus_sasl",
"headers": [
"sasl"
],
"libraries": [
"libsasl*"
]
},
{
"package": "editline",
"headers": [
"editline.h",
"editline"
],
"libraries": [
"libedit.*",
"libeditline.*"
]
},
{
"package": "html-tidy",
"headers": [
"tidy*"
],
"libraries": [
"libtidy.*"
]
},
{
"package": "hunspell",
"headers": [
"hunspell"
],
"libraries": [
"libhunspell*"
]
},
{
"package": "icu",
"headers": [
"unicode"
],
"libraries": [
"libicucore.*"
]
},
{
"package": "libarchive",
"headers": [
"archive.h",
"archive_entry.h"
],
"libraries": [
"libarchive.*"
]
},
{
"package": "libc++",
"headers": [
"c++",
"cxxabi.h",
"__cxxabi_config.h"
],
"libraries": [
"libc++*"
]
},
{
"package": "ld64",
"libraries": [
"libcodedirectory.*",
"libcodedirectory_static.*"
]
},
{
"package": "expat",
"headers": [
"expat.h",
"expat_config.h",
"expat_external.h"
],
"libraries": [
"libexpat.*"
]
},
{
"package": "libffi",
"headers": [
"ffi*"
],
"libraries": [
"libffi*"
]
},
{
"package": "libgcc",
"libraries": [
"libgcc*"
]
},
{
"package": "libiconv",
"headers": [
"iconv.h",
"libcharset.h",
"localcharset.h"
],
"libraries": [
"libcharset.*",
"libiconv.*",
"i18n"
]
},
{
"package": "libiodbc",
"libraries": [
"libiodbc*"
]
},
{
"package": "libkrb4",
"libraries": [
"libkrb4.*"
]
},
{
"package": "libkrb5",
"headers": [
"com_err.h",
"gssapi",
"gssapi.h",
"gssrpc",
"kadm5",
"kdb.h",
"krad.h",
"krb5",
"krb5.h",
"profile.h",
"verto-module.h",
"verto.h"
],
"libraries": [
"krb5",
"libcom_err.*",
"libgssapi_krb5.*",
"libgssrpc.*",
"libk5crypto.*",
"libkadm5clnt.*",
"libkadm5clnt_mit.*",
"libkadm5srv.*",
"libkadm5srv_mit.*",
"libkdb5.*",
"libkrad.*",
"libkrb5*",
"libkrb5support.*",
"libverto.*"
]
},
{
"package": "libpcap",
"headers": [
"pcap*"
],
"libraries": [
"libpcap.*"
]
},
{
"package": "libresolv",
"headers": [
"arpa/nameser.h",
"arpa/nameser_compat.h",
"dns.h",
"dns_util.h",
"nameser.h",
"resolv.h"
],
"libraries": [
"libresolv.*"
]
},
{
"package": "libstdc++",
"libraries": [
"libstdc++.*"
]
},
{
"package": "libsbuf",
"headers": [
"usbuf.h"
],
"libraries": [
"libsbuf.*"
]
},
{
"package": "libtermcap",
"headers": [
"termcap.h"
],
"libraries": [
"libtermcap.*"
]
},
{
"package": "libutil",
"headers": [
"libutil.h"
],
"libraries": [
"libutil.*",
"libutil1.*"
]
},
{
"package": "libxml2",
"headers": [
"libxml",
"libxml2"
],
"libraries": [
"libxml2.*"
]
},
{
"package": "libxo",
"headers": [
"libxo"
],
"libraries": [
"libxo.*"
]
},
{
"package": "libxslt",
"headers": [
"libexslt",
"libxslt"
],
"libraries": [
"libexslt.*",
"libxslt.*"
]
},
{
"package": "liby",
"libraries": [
"liby.a"
]
},
{
"package": "marisa-trie",
"libraries": [
"libmarisa.*"
]
},
{
"package": "ncurses",
"headers": [
"curses*",
"cursslk.h",
"eti.h",
"etip.h",
"form.h",
"menu.h",
"nc_tparm.h",
"ncurses*",
"panel.h",
"term.h",
"term_entry.h",
"termcap.h",
"tic.h",
"unctrl.h"
],
"libraries": [
"libcurses.*",
"libform.*",
"libformw.*",
"libmenu.*",
"libmenuw.*",
"libncurses.*",
"libncursesw.*",
"libpanel.*",
"libpanelw.*",
"libtinfo.*"
]
},
{
"package": "net-snmp",
"headers": [
"net-snmp"
],
"libraries": [
"libnetsnmp*"
]
},
{
"package": "nghttp",
"libraries": [
"lib*nghttp2.*"
]
},
{
"package": "openblas",
"headers": [
"cblas.h",
"f77blas.h",
"lapack.h",
"lapacke.h",
"lapacke_config.h",
"lapacke_mangling.h",
"lapacke_utils.h",
"openblas_config.h"
],
"libraries": [
"libblas.*",
"libcblas.*",
"libclapack.*",
"libf77lapack.*",
"liblapack.*",
"liblapacke.*",
"libopenblas.*",
"libopenblas.*",
"libopenblasp*"
]
},
{
"package": "openldap",
"libraries": [
"liblber.*",
"liblber_r.*",
"libldap.*",
"libldap_r.*"
]
},
{
"package": "openpam",
"headers": [
"security"
],
"libraries": [
"libpam.*",
"pam_*"
]
},
{
"package": "pcre",
"headers": [
"pcre.h",
"pcreposix.h"
],
"libraries": [
"libpcre.*",
"libpcre2*",
"libpcreposix.*"
]
},
{
"package": "php",
"headers": [
"php"
],
"libraries": [
"php"
]
},
{
"package": "postgresql",
"libraries": [
"libecpg*",
"libpg*",
"libpq*"
]
},
{
"package": "python",
"headers": [
"python*"
],
"frameworks": [
"Python.framework"
],
"libraries": [
"libpython*",
"python*"
]
},
{
"package": "readline",
"headers": [
"readline"
],
"libraries": [
"libhistory.*",
"libreadline.*"
]
},
{
"package": "ruby",
"frameworks": [
"Ruby.framework"
],
"libraries": [
"libruby.*",
"ruby"
]
},
{
"package": "sqlite3",
"headers": [
"sqlite3.h",
"sqlite3ext.h"
],
"libraries": [
"libsqlite3.*"
]
},
{
"package": "swift",
"libraries": [
"swift/shims"
]
},
{
"package": "tcl",
"headers": [
"tcl*",
"tk*"
],
"frameworks": [
"Tcl.framework",
"Tk.framework"
],
"libraries": [
"libtcl*",
"libtk*",
"tclConfig.sh",
"tkConfig.sh"
]
},
{
"package": "xar",
"headers": [
"xar"
],
"libraries": [
"libxar.*"
]
},
{
"package": "xz",
"headers": [
"lzma*"
],
"libraries": [
"liblzma.*"
]
},
{
"package": "zlib",
"headers": [
"zconf.h",
"zlib.h"
],
"libraries": [
"libz.*"
]
}
]

View File

@@ -0,0 +1,26 @@
{
"14": {
"urls": [
"https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250211001355/https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg"
],
"version": "14.4",
"hash": "sha256-QozDiwY0Czc0g45vPD7G4v4Ra+3DujCJbSads3fJjjM="
},
"15": {
"urls": [
"https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250530132510/https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg"
],
"version": "15.5",
"hash": "sha256-HBiSJuw1XBUK5R/8Sj65c3rftSEvQl/O9ZZVp/g1Amo="
},
"26": {
"urls": [
"https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg",
"https://web.archive.org/web/20250915230423/https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg"
],
"version": "26.2",
"hash": "sha256-hXRlMieVv0smna5uiWRwq87IWOaPWtAjAldbi+wQXcw="
}
}

nix/apple-sdk/package.nix
View File

@@ -0,0 +1,110 @@
let
sdkVersions = builtins.fromJSON (builtins.readFile ./metadata/versions.json);
in
{ lib
, stdenv
, stdenvNoCC
, substitute
, # Specifies the major version used for the SDK. Uses `hostPlatform.darwinSdkVersion` by default.
darwinSdkMajorVersion ? lib.versions.major stdenv.hostPlatform.darwinSdkVersion
, # Enabling bootstrap disables propagation. Defaults to `false` (meaning to propagate certain packages and `xcrun`)
# except in stage0 of the Darwin stdenv bootstrap.
enableBootstrap ? stdenv.name == "bootstrap-stage0-stdenv-darwin"
, # Required by various phases
callPackage
,
}:
let
sdkInfo =
sdkVersions.${darwinSdkMajorVersion}
or (lib.throw "Unsupported SDK major version: ${darwinSdkMajorVersion}");
sdkVersion = sdkInfo.version;
fetchSDK = callPackage ./common/fetch-sdk.nix { };
phases = lib.composeManyExtensions (
[
(callPackage ./common/add-core-symbolication.nix { })
(callPackage ./common/derivation-options.nix { })
(callPackage ./common/passthru-private-frameworks.nix { inherit sdkVersion; })
(callPackage ./common/passthru-source-release-files.nix { inherit sdkVersion; })
(callPackage ./common/remove-disallowed-packages.nix { })
(callPackage ./common/process-stubs.nix { })
]
# Avoid infinite recursions by not propagating certain packages, so they can themselves build with the SDK.
++ lib.optionals (!enableBootstrap) [
(callPackage ./common/propagate-inputs.nix { })
(callPackage ./common/propagate-xcrun.nix { inherit sdkVersion; })
]
# This has to happen last.
++ [
(callPackage ./common/run-build-phase-hooks.nix { })
]
);
in
stdenvNoCC.mkDerivation (
lib.extends phases (finalAttrs: {
pname = "apple-sdk";
inherit (sdkInfo) version;
src = fetchSDK sdkInfo;
dontConfigure = true;
strictDeps = true;
setupHooks = [
# `role.bash` is copied from `../build-support/setup-hooks/role.bash` due to the requirements not to reference
# paths outside the package when it is in `by-name`. It needs to be kept in sync, but it fortunately does not
# change often. Once `build-support` is available as a package (or some other mechanism), it should be changed
# to whatever that replacement is.
./setup-hooks/role.bash
(substitute {
src = ./setup-hooks/sdk-hook.sh;
substitutions = [
"--subst-var-by"
"sdkVersion"
(lib.escapeShellArgs (lib.splitVersion sdkVersion))
];
})
];
installPhase =
let
sdkName = "MacOSX${lib.versions.majorMinor sdkVersion}.sdk";
sdkMajor = lib.versions.major sdkVersion;
in
''
runHook preInstall
mkdir -p "$sdkpath"
cp -rd . "$sdkpath/${sdkName}"
ln -s "${sdkName}" "$sdkpath/MacOSX${sdkMajor}.sdk"
ln -s "${sdkName}" "$sdkpath/MacOSX.sdk"
# Swift adds these locations to its search paths. Avoid spurious warnings by making sure they exist.
mkdir -p "$platformPath/Developer/Library/Frameworks"
mkdir -p "$platformPath/Developer/Library/PrivateFrameworks"
mkdir -p "$platformPath/Developer/usr/lib"
runHook postInstall
'';
passthru = {
sdkroot = finalAttrs.finalPackage + "/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk";
};
__structuredAttrs = true;
meta = {
description = "Frameworks and libraries required for building packages on Darwin";
homepage = "https://developer.apple.com";
teams = [ lib.teams.darwin ];
platforms = lib.platforms.darwin;
badPlatforms = [ lib.systems.inspect.patterns.is32bit ];
};
})
)

View File

@@ -0,0 +1,48 @@
From 6531da946949a94643e6d8424236174ae64fe0ca Mon Sep 17 00:00:00 2001
From: Randy Eckenrode <randy@largeandhighquality.com>
Date: Sat, 30 Sep 2023 18:02:39 -0400
Subject: [PATCH 1/2] Add function definitions needed to build zlog in
system_cmds
---
CoreSymbolication.h | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
index a413860..f3cf63f 100644
--- a/CoreSymbolication.h
+++ b/CoreSymbolication.h
@@ -324,7 +324,9 @@ CSSymbolOwnerEditRelocations
CSSymbolOwnerForeachRegion
CSSymbolOwnerForeachRegionWithName
CSSymbolOwnerForeachSection
-CSSymbolOwnerForeachSegment
+*/
+void CSSymbolOwnerForeachSegment(CSSymbolOwnerRef owner, void (^block)(CSSegmentRef));
+/*
CSSymbolOwnerForeachSourceInfo
CSSymbolOwnerForeachSymbol
*/
@@ -333,7 +335,9 @@ void CSSymbolOwnerForeachSymbolWithName(CSSymbolOwnerRef owner, const char *sna
/*
CSSymbolOwnerGetArchitecture
CSSymbolOwnerGetBaseAddress
-CSSymbolOwnerGetCFUUIDBytes
+*/
+const CFUUIDBytes* CSSymbolOwnerGetCFUUIDBytes(CSSymbolOwnerRef owner);
+/*
CSSymbolOwnerGetCompatibilityVersion
CSSymbolOwnerGetCurrentVersion
CSSymbolOwnerGetDataFlags
@@ -390,7 +394,7 @@ CSSymbolOwnerSetLoadTimestamp
CSSymbolOwnerSetPath
CSSymbolOwnerSetRelocationCount
*/
-CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
+void CSSymbolOwnerSetTransientUserData(CSSymbolOwnerRef owner, uint32_t gen);
/*
CSSymbolOwnerSetUnloadTimestamp
*/
--
2.44.1

View File

@@ -0,0 +1,45 @@
From ae7ac6a7043dbae8e63d6ce5e63dfaf02b5977fe Mon Sep 17 00:00:00 2001
From: Randy Eckenrode <randy@largeandhighquality.com>
Date: Sat, 30 Sep 2023 18:37:18 -0400
Subject: [PATCH 2/2] Add CF_EXPORT To const symbols
---
CoreSymbolication.h | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/CoreSymbolication.h b/CoreSymbolication.h
index f3cf63f..4124a54 100644
--- a/CoreSymbolication.h
+++ b/CoreSymbolication.h
@@ -49,6 +49,7 @@
#include <CoreFoundation/CoreFoundation.h>
+#include <CoreFoundation/CFBase.h>
#include <mach/mach.h>
@@ -139,13 +140,13 @@ typedef void (^CSSegmentIterator)(CSSegmentRef segment);
* External symbols
*/
-const char* kCSRegionMachHeaderName;
-const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
-const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
-const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
-const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
-const CSSetCallBacks kCSTypeSetCallBacks;
-const CSSetCallBacks kCSTypeSetWeakCallBacks;
+CF_EXPORT const char* kCSRegionMachHeaderName;
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryKeyCallBacks;
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryValueCallBacks;
+CF_EXPORT const CSDictionaryKeyCallBacks kCSTypeDictionaryWeakKeyCallBacks;
+CF_EXPORT const CSDictionaryValueCallBacks kCSTypeDictionaryWeakValueCallBacks;
+CF_EXPORT const CSSetCallBacks kCSTypeSetCallBacks;
+CF_EXPORT const CSSetCallBacks kCSTypeSetWeakCallBacks;
/*
--
2.44.1

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils curl file gzip jq xcbuild yq
set -eu -o pipefail
catalog=${1-}
if [ -z "$catalog" ]; then
echo "usage: get-sdks-from-catalog.sh <catalog>"
echo " <catalog> Apple software update catalog (may be gzipped)" >&2
exit 1
fi
scratch=$(mktemp)
trap 'rm -f -- "$scratch"' EXIT
if [[ "$(file "$catalog")" =~ gzip ]]; then
gzcat "$catalog" >"$scratch"
else
cp --reflink=auto "$catalog" "$scratch"
fi
# Grab all SDK packages from the catalog
filter='.Products[].Packages[] | select(.URL | test(".*CLTools_macOSNMOS_SDK.pkg")) | "\(.URL)|\(.MetadataURL)"'
declare -A package_list
for package in $(plutil -convert json -o - "$scratch" | jq -r "$filter"); do
package_list[${package%%|*}]=${package#*|}
done
truncate --size 0 "$scratch"
for pkg in "${!package_list[@]}"; do
ver=$(curl --silent "${package_list[$pkg]}" | xq -r '."pkg-info"."@version"')
echo "{\"url\": \"$pkg\", \"version\": \"$(cut -d. -f1-3 <<<"$ver")\", \"long_version\": \"$ver\"}" >>"$scratch"
done
jq -r --slurp '
group_by(.version | split(".")[0])
| map(max_by(.version))
| sort_by(.version)[]
| "Package URL: \(.url)\n Xcode Ver: \(.version) (\(.long_version))\n"' "$scratch"

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils curl git gnutar jq moreutils nix
set -eu -o pipefail
if [ ! -v 2 ]; then
echo "usage: lock-sdk-deps.sh <SDK version> <Packages>" >&2
echo " <SDK version> Decimal-separated version number." >&2
echo " Must correspond to a tag in https://github.com/apple-oss-distributions/distribution-macOS" >&2
echo " <Packages> List of packages from the distributions-macOS repository." >&2
echo " Packages not in the repository at the tag for <SDK version> will be ignored."
exit 1
fi
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
lockfile=$pkgdir/metadata/apple-oss-lockfile.json
if [ ! -e "$lockfile" ]; then
touch "$lockfile"
fi
workdir=$(mktemp -d)
trap 'rm -rf -- "$workdir"' EXIT
sdkVersion=$1
shift
tag="macos-${sdkVersion//./}"
declare -a packages=("$@")
echo "Locking versions for macOS $sdkVersion using tag '$tag'..."
pushd "$workdir" >/dev/null
git clone --branch "$tag" https://github.com/apple-oss-distributions/distribution-macOS.git &>/dev/null
cd distribution-macOS
for package in "${packages[@]}"; do
# If the tag exists in `release.json`, use that as an optimization to avoid downloading unnecessarily from GitHub.
packageTag=$(jq -r --arg package "$package" '.projects[] | select(.project == $package) | .tag' release.json)
packageCommit=$(git ls-tree -d HEAD "$package" | awk '{print $3}')
if [ ! -d "$package" ]; then
packageCommit=HEAD
fi
# However, sometimes it doesn't exist. In that case, fall back to cloning the repo and checking manually
# which tag corresponds to the commit from the submodule.
if [ -z "$packageTag" ]; then
git clone --no-checkout "https://github.com/apple-oss-distributions/$package.git" ../source &>/dev/null
pushd ../source >/dev/null
packageTag=$(git tag --points-at "$packageCommit")
popd >/dev/null
rm -rf ../source
fi
packageVersion=${packageTag##"$package"-}
curl -OL "https://github.com/apple-oss-distributions/$package/archive/$packageTag.tar.gz" &>/dev/null
tar axf "$packageTag.tar.gz"
packageHash=$(nix --extra-experimental-features nix-command hash path "$package-$packageTag")
pkgsjson="{\"$sdkVersion\": {\"$package\": {\"version\": \"$packageVersion\", \"hash\": \"$packageHash\"}}}"
echo " - Locking $package to version $packageVersion with hash '$packageHash'"
jq --argjson pkg "$pkgsjson" -S '. * $pkg' "$lockfile" | sponge "$lockfile"
done
popd >/dev/null
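A hypothetical run, with the resulting lockfile shape implied by the pkgsjson merge above:

# Illustrative; 26.2 must be a distribution-macOS tag and the packages must exist there.
./lock-sdk-deps.sh 26.2 Libc xnu
# apple-oss-lockfile.json gains (deep-merged) entries of the form:
# {"26.2": {"Libc": {"version": "...", "hash": "sha256-..."}, "xnu": {...}}}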

View File

@@ -0,0 +1,62 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p coreutils jq
set -eu -o pipefail
pkgdir=$(dirname "$(dirname "$(realpath "$0")")")
echo '{}' >"$pkgdir/metadata/apple-oss-lockfile.json"
declare -a versions
readarray -t versions < <(jq -r '.[].version' "$pkgdir/metadata/versions.json")
declare -a packages=(
CarbonHeaders
CommonCrypto
IOAudioFamily
IOFireWireFamily
IOFWDVComponents
IOFireWireAVC
IOFireWireSBP2
IOFireWireSerialBusProtocolTransport
IOGraphics
IOHIDFamily
IONetworkingFamily
IOSerialFamily
IOStorageFamily
IOBDStorageFamily
IOCDStorageFamily
IODVDStorageFamily
IOUSBFamily
IOKitUser
Libc
Libinfo
Libm
Libnotify
Librpcsvc
Libsystem
OpenDirectory
Security
architecture
configd
copyfile
dtrace
dyld
eap8021x
hfs
launchd
libclosure
libdispatch
libmalloc
libplatform
libpthread
mDNSResponder
objc4
ppp
removefile
xnu
)
for version in "${versions[@]}"; do
"$pkgdir/scripts/lock-sdk-deps.sh" "$version" "${packages[@]}"
done

View File

@@ -0,0 +1,6 @@
function enablePrivateFrameworks() {
export NIX_CFLAGS_COMPILE+=" -iframework $DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
export NIX_LDFLAGS+=" -F$DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/PrivateFrameworks"
}
preConfigureHooks+=(enablePrivateFrameworks)

View File

@@ -0,0 +1,71 @@
# Since the same derivation can be depended on in multiple ways, we need to
# accumulate *each* role (i.e. host and target platforms relative to the depending
# derivation) in which the derivation is used.
#
# The role is intended to be used as part of other variables names like
# - $NIX_SOMETHING${role_post}
function getRole() {
case $1 in
-1)
role_post='_FOR_BUILD'
;;
0)
role_post=''
;;
1)
role_post='_FOR_TARGET'
;;
*)
echo "@name@: used as improper sort of dependency" >&2
return 1
;;
esac
}
# `hostOffset` describes how the host platform of the package is slid relative
# to the depending package. `targetOffset` likewise describes the target
# platform of the package. Both are brought into scope of the setup hook defined
# for the dependency whose setup hook is being processed relative to the package
# being built.
function getHostRole() {
getRole "$hostOffset"
}
function getTargetRole() {
getRole "$targetOffset"
}
# `depHostOffset` describes how the host platform of the dependencies are slid
# relative to the depending package. `depTargetOffset` likewise describes the
# target platform of dependencies. Both are brought into scope of the
# environment hook defined for the dependency being applied relative to the
# package being built.
function getHostRoleEnvHook() {
getRole "$depHostOffset"
}
function getTargetRoleEnvHook() {
getRole "$depTargetOffset"
}
# This variant is intended specifically for code-producing tool wrapper scripts
# `NIX_@wrapperName@_TARGET_*_@suffixSalt@` tracks this (needs to be an exported
# env var so can't use fancier data structures).
function getTargetRoleWrapper() {
case $targetOffset in
-1)
export NIX_@wrapperName@_TARGET_BUILD_@suffixSalt@=1
;;
0)
export NIX_@wrapperName@_TARGET_HOST_@suffixSalt@=1
;;
1)
export NIX_@wrapperName@_TARGET_TARGET_@suffixSalt@=1
;;
*)
echo "@name@: used as improper sort of dependency" >&2
return 1
;;
esac
}
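A quick sketch of the offset-to-suffix mapping in use (values per the case arms above; the variable name comes from the SDK hook below):

# Illustrative only: a dependency in nativeBuildInputs has hostOffset=-1.
hostOffset=-1
getHostRole                                 # sets role_post=_FOR_BUILD
echo "NIX_APPLE_SDK_VERSION${role_post}"    # -> NIX_APPLE_SDK_VERSION_FOR_BUILD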

View File

@@ -0,0 +1,17 @@
local role_post
getHostRole
local sdkVersionVar=NIX_APPLE_SDK_VERSION${role_post}
local developerDirVar=DEVELOPER_DIR${role_post}
local sdkVersionArr=(@sdkVersion@)
local sdkVersion
sdkVersion=$(printf "%02d%02d%02d" "${sdkVersionArr[0]-0}" "${sdkVersionArr[1]-0}" "${sdkVersionArr[2]-0}")
if [ "$sdkVersion" -gt "${!sdkVersionVar-000000}" ]; then
export "$developerDirVar"='@out@'
export "$sdkVersionVar"="$sdkVersion"
export "SDKROOT${role_post}"="${!developerDirVar}/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk"
fi
unset -v role_post developerDirVar sdkVersion sdkVersionArr sdkVersionVar
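To see the encoding at work (an illustration, not part of the hook): lib.splitVersion turns "26.2" into the words substituted at @sdkVersion@, and the zero-padding makes the string comparison behave numerically across components:

# Illustrative: "26.2" -> (26 2) -> 260200, which compares correctly against "15.5" -> 150500.
sdkVersionArr=(26 2)
printf "%02d%02d%02d\n" "${sdkVersionArr[0]-0}" "${sdkVersionArr[1]-0}" "${sdkVersionArr[2]-0}"
# prints: 260200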

View File

@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
version = let v = "0.30.5"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -86,6 +86,7 @@ let
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeBool "MLX_BUILD_CPU" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"mlx==0.30.5; sys_platform == 'darwin'",
"mlx[cpu]==0.30.5; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -31,8 +31,6 @@ dependencies = [
]
[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
# dependencies only required for development
@@ -63,7 +61,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
@@ -105,6 +103,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
required-version = ">=0.8.6"
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",

View File

@@ -59,6 +59,22 @@
}
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;
runtimeInputs = [ exoVenv ];
runtimeEnv = {
EXO_DASHBOARD_DIR = self'.packages.dashboard;
EXO_RESOURCES_DIR = inputs.self + /resources;
};
text = ''exec python ${path} "$@"'';
};
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
inherit name;
runtimeInputs = [ pkgs.python313 ];
text = ''exec python ${path} "$@"'';
};
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
@@ -66,28 +82,30 @@
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
# Create wrapper script
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
{
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
} // {
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
${pkgs.ruff}/bin/ruff check ${inputs.self}
touch $out
'';
};

View File

@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -53,11 +54,10 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
@@ -108,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads and model_id in self.download_status:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads.pop(model_id).cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id

View File

@@ -158,6 +158,78 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
async def _build_file_list_from_local_directory(
model_id: ModelId,
recursive: bool = False,
) -> list[FileListEntry] | None:
"""Build a file list from locally existing model files.
We can only figure out which files we need from the safetensors index, so
a local directory must contain a *.safetensors.index.json and the
safetensors files listed there.
"""
model_dir = (await ensure_models_dir()) / model_id.normalize()
if not await aios.path.exists(model_dir):
return None
def _scan() -> list[FileListEntry] | None:
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
if not index_files:
return None
entries_by_path: dict[str, FileListEntry] = {}
if recursive:
for dirpath, _, filenames in os.walk(model_dir):
for filename in filenames:
if filename.endswith(".partial"):
continue
full_path = Path(dirpath) / filename
rel_path = str(full_path.relative_to(model_dir))
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=full_path.stat().st_size,
)
else:
for item in model_dir.iterdir():
if item.is_file() and not item.name.endswith(".partial"):
entries_by_path[item.name] = FileListEntry(
type="file",
path=item.name,
size=item.stat().st_size,
)
# Add expected weight files from index that haven't been downloaded yet
for index_file in index_files:
try:
index_data = ModelSafetensorsIndex.model_validate_json(
index_file.read_text()
)
relative_dir = index_file.parent.relative_to(model_dir)
for filename in set(index_data.weight_map.values()):
rel_path = (
str(relative_dir / filename)
if relative_dir != Path(".")
else filename
)
if rel_path not in entries_by_path:
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=None,
)
except Exception:
continue
return list(entries_by_path.values())
file_list = await asyncio.to_thread(_scan)
if not file_list:
return None
return file_list
_fetched_file_lists_this_session: set[str] = set()
@@ -183,6 +255,14 @@ async def fetch_file_list_with_cache(
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
local_file_list = await _build_file_list_from_local_directory(
model_id, recursive
)
if local_file_list is not None:
logger.warning(
f"No internet and no cached file list for {model_id} - using local file list"
)
return local_file_list
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
@@ -203,10 +283,18 @@ async def fetch_file_list_with_cache(
except Exception as e:
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
f"No internet and no cached file list for {model_id} - using local file list"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
local_file_list = await _build_file_list_from_local_directory(
model_id, recursive
)
if local_file_list is not None:
logger.warning(
f"Failed to fetch file list for {model_id} and no cache exists, "
"using local file list"
)
return local_file_list
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
@@ -378,10 +466,14 @@ async def download_file_with_retry(
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
on_connection_lost()
raise e
break
logger.error(
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)

View File

@@ -195,6 +195,10 @@ class ResumableShardDownloader(ShardDownloader):
self, shard: ShardMetadata
) -> RepoDownloadProgress:
_, progress = await download_shard(
shard, self.on_progress_wrapper, skip_download=True
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return progress

View File

@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -106,6 +105,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
er_send, er_recv = channel[ElectionResult]()
@@ -136,7 +136,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +147,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -188,6 +189,9 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.master.run)
elif (

View File

@@ -176,7 +176,7 @@ async def generate_chat_stream(
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str]:
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
combined_text = "".join(text_parts)
assert model is not None
yield ChatCompletionResponse(
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
@@ -241,5 +241,4 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
).model_dump_json()
return
)

View File

@@ -123,7 +123,6 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
@@ -530,14 +529,16 @@ class API:
break
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
@@ -632,14 +633,11 @@ class API:
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -655,7 +653,8 @@ class API:
command = TextGeneration(task_params=task_params)
await self._send(command)
return await self._collect_text_generation_with_stats(command.command_id)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
@@ -857,11 +856,6 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -943,11 +937,6 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1331,29 +1320,40 @@ class API:
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -6,6 +6,7 @@ from loguru import logger
from exo.master.placement import (
add_instance_to_placements,
cancel_unnecessary_downloads,
delete_instance,
get_transition_events,
place_instance,
@@ -16,12 +17,12 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
TextGeneration,
@@ -37,7 +38,6 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
TracesMerged,
@@ -68,12 +68,9 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -83,6 +80,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -98,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -278,8 +278,16 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
@@ -290,7 +298,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -300,7 +308,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -310,18 +318,6 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -330,9 +326,10 @@ class Master:
]
)
)
self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -15,20 +15,20 @@ from exo.master.placement_utils import (
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -186,7 +186,6 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -202,18 +201,6 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,
@@ -221,3 +208,29 @@ def get_transition_events(
)
return events
def cancel_unnecessary_downloads(
instances: Mapping[InstanceId, Instance],
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
commands: list[DownloadCommand] = []
currently_downloading = [
(k, v.shard_metadata.model_card.model_id)
for k, vs in download_status.items()
for v in vs
if isinstance(v, DownloadOngoing)
]
active_models = set(
(
node_id,
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
)
for instance in instances.values()
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
)
for pair in currently_downloading:
if pair not in active_models:
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
return commands
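The check reduces to a set difference over (node, model) pairs; a toy illustration with hypothetical stand-in strings for NodeId/ModelId:

# Hypothetical values; the real code compares (NodeId, ModelId) tuples.
downloading = {("node-a", "llama-3"), ("node-a", "qwen-3")}
active = {("node-a", "llama-3")}
print(downloading - active)  # {('node-a', 'qwen-3')} -> gets a CancelDownload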

View File

@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
CommandId,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
TextGeneration,
)
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
download_command_sender=fcds,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances, {})
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances, {})
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances, {})
events = get_transition_events(current_instances, target_instances)
# assert
assert len(events) == 1

View File

@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:

View File

@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")

View File

@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -76,7 +72,12 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (
@@ -88,7 +89,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -3,10 +3,11 @@
from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]

View File

@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -61,11 +60,6 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -93,7 +87,6 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -194,9 +194,10 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
return await to_thread.run_sync(
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.base import (
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import MiniMaxAttention
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
if TYPE_CHECKING:
from mlx_lm.models.cache import Cache
TimeoutCallback = Callable[[], None]
@@ -503,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
layer.self_attn.kv_b_proj
)
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
# layer.self_attn.kv_b_proj
# )
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
@@ -624,6 +644,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
return y
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
self.group = group
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: "Cache | None" = None,
) -> mx.array:
batch_dim, seq_dim, _ = x.shape
self._original_layer = cast(MiniMaxAttention, self.original_layer) # type: ignore
queries: mx.array = self._original_layer.q_proj(x)
keys: mx.array = self._original_layer.k_proj(x)
values: mx.array = self._original_layer.v_proj(x)
if getattr(self, "use_qk_norm", False):
q_dim = queries.shape[-1]
k_dim = keys.shape[-1]
n = self.group.size()
qk = mx.concatenate(
[queries, keys], axis=-1
) # (batch_dim, seq_dim, q_dim + k_dim)
qk = mx.distributed.all_gather(
qk, group=self.group
) # (n*batch_dim, seq_dim, q_dim + k_dim)
qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
queries = qk[..., :q_dim].reshape(
batch_dim, seq_dim, -1
) # (batch_dim, seq_dim, n * q_dim)
keys = qk[..., q_dim:].reshape(
batch_dim, seq_dim, -1
) # (batch_dim, seq_dim, n * k_dim)
queries = self._original_layer.q_norm(queries)
keys = self._original_layer.k_norm(keys)
# Split back and take this rank's portion
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
keys = mx.split(keys, n, axis=-1)[self.group.rank()]
queries = queries.reshape(
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
).transpose(0, 2, 1, 3)
keys = keys.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)
values = values.reshape(
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
).transpose(0, 2, 1, 3)
if cache is not None:
queries = self._original_layer.rope(queries, offset=cache.offset)
keys = self._original_layer.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self._original_layer.rope(queries)
keys = self._original_layer.rope(keys)
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self._original_layer.scale, # type: ignore
mask=mask,
)
output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
return self._original_layer.o_proj(output)
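A single-process sketch of the gather-then-split index math above, using numpy in place of mx and axis-0 concatenation to emulate all_gather (illustrative only; shapes as in the comments):

import numpy as np

n, B, S, D = 4, 2, 3, 8  # ranks, batch, seq, per-rank q_dim + k_dim
per_rank = [np.random.rand(B, S, D) for _ in range(n)]
gathered = np.concatenate(per_rank, axis=0)  # all_gather: (n*B, S, D)
full = gathered.reshape(n, B, S, D).transpose(1, 2, 0, 3).reshape(B, S, n * D)
# After the (unsharded) q_norm/k_norm run on the full width, each rank
# takes back exactly its own slice:
for rank in range(n):
    assert np.allclose(np.split(full, n, axis=-1)[rank], per_rank[rank])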
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -632,7 +730,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
rank = self.group.rank()
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -643,18 +740,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Shard qk_norm weights if present (must match sharded head count)
if getattr(layer.self_attn, "use_qk_norm", False):
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
@@ -679,18 +769,95 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
if isinstance(layer, Qwen3DecoderLayer):
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
else:
assert isinstance(layer, Qwen3NextDecoderLayer)
if hasattr(layer, "linear_attn"):
linear_attn = layer.linear_attn
linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
linear_attn.in_proj_qkvz
)
linear_attn.in_proj_ba = self.all_to_sharded_linear(
linear_attn.in_proj_ba
)
linear_attn.out_proj = self.sharded_to_all_linear(
linear_attn.out_proj
)
# Shard conv1d: depthwise conv with non-contiguous channel slicing.
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
# Each rank takes its head-slice from each of the three sections.
rank = self.group.rank()
key_dim = linear_attn.key_dim
value_dim = linear_attn.value_dim
key_dim_shard = key_dim // self.N
value_dim_shard = value_dim // self.N
q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
k_idx = mx.arange(
key_dim + rank * key_dim_shard,
key_dim + (rank + 1) * key_dim_shard,
)
v_idx = mx.arange(
2 * key_dim + rank * value_dim_shard,
2 * key_dim + (rank + 1) * value_dim_shard,
)
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
new_conv_dim = key_dim_shard * 2 + value_dim_shard
linear_attn.conv1d.groups = new_conv_dim
num_v_shard = linear_attn.num_v_heads // self.N
v_start = rank * num_v_shard
v_end = v_start + num_v_shard
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]
linear_attn.num_k_heads //= self.N
linear_attn.num_v_heads //= self.N
linear_attn.key_dim = (
linear_attn.head_k_dim * linear_attn.num_k_heads
)
linear_attn.value_dim = (
linear_attn.head_v_dim * linear_attn.num_v_heads
)
linear_attn.conv_dim = (
linear_attn.key_dim * 2 + linear_attn.value_dim
)
else:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
@@ -700,6 +867,14 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
self.all_to_sharded_linear_in_place(
layer.mlp.shared_expert.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_expert.down_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
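To make the non-contiguous conv1d slicing concrete, here is the index arithmetic from the block above on a toy configuration (hypothetical sizes, plain Python):

```python
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)]; each rank takes its
# head-slice from each of the three sections rather than one contiguous range.
key_dim, value_dim, N, rank = 8, 16, 2, 1  # hypothetical dims, 2-way shard, rank 1
kd, vd = key_dim // N, value_dim // N      # per-rank widths: 4 and 8

q_idx = range(rank * kd, (rank + 1) * kd)                              # 4..7
k_idx = range(key_dim + rank * kd, key_dim + (rank + 1) * kd)          # 12..15
v_idx = range(2 * key_dim + rank * vd, 2 * key_dim + (rank + 1) * vd)  # 24..31

conv_indices = [*q_idx, *k_idx, *v_idx]
assert len(conv_indices) == kd * 2 + vd  # new_conv_dim == 16
print(conv_indices)  # [4, 5, 6, 7, 12, 13, 14, 15, 24, ..., 31]
```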


@@ -1,16 +1,14 @@
import os
from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
ArraysCache,
KVCache,
QuantizedKVCache,
RotatingKVCache,
trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
@@ -26,51 +24,119 @@ _MEMORY_THRESHOLD = float(
)
class KVPrefixCache:
class CacheSnapshot:
"""Snapshot of states at a known token position."""
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
):
self.states = states
self.token_count = token_count
def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
states: list[ArraysCache | RotatingKVCache | None] = []
for c in cache:
if isinstance(c, (ArraysCache, RotatingKVCache)):
states.append(deepcopy(c))
else:
states.append(None)
token_count = cache_length(cache)
return CacheSnapshot(states=states, token_count=token_count)
def _find_nearest_snapshot(
snapshots: list[CacheSnapshot],
target_token_count: int,
) -> CacheSnapshot | None:
best: CacheSnapshot | None = None
for snap in snapshots:
if snap.token_count <= target_token_count and (
best is None or snap.token_count > best.token_count
):
best = snap
return best
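A worked example of the selection rule, assuming the CacheSnapshot definition above (hypothetical positions):

```python
# Snapshots taken at 32, 64 and 96 tokens; pick the largest not exceeding the target.
snaps = [CacheSnapshot(states=[None], token_count=t) for t in (32, 64, 96)]

assert _find_nearest_snapshot(snaps, 96).token_count == 96  # exact hit
assert _find_nearest_snapshot(snaps, 80).token_count == 64  # nearest at-or-before
assert _find_nearest_snapshot(snaps, 20) is None            # nothing usable
```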
def has_non_kv_caches(cache: KVCacheType) -> bool:
"""Check if a cache contains any ArraysCache (SSM) entries."""
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)
class KVPrefixCache:
def __init__(self, group: mx.distributed.Group | None = None):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._snapshots: list[list[CacheSnapshot] | None] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._snapshots.clear()
self._last_used.clear()
def add_kv_cache(self, prompt: str, cache: KVCacheType):
def add_kv_cache(
self,
prompt_tokens: mx.array,
cache: KVCacheType,
ssm_snapshots: list[CacheSnapshot] | None = None,
):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.prompts.append(prompt_tokens)
self.caches.append(deepcopy(cache))
self._snapshots.append(ssm_snapshots)
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
def update_kv_cache(
self,
index: int,
prompt: str,
prompt_tokens: mx.array,
cache: KVCacheType,
snapshots: list[CacheSnapshot] | None,
restore_pos: int,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts[index] = tokenized_prompt
old_snapshots = self._snapshots[index]
merged: list[CacheSnapshot] = []
if old_snapshots:
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
if snapshots:
merged.extend(snapshots)
self.prompts[index] = prompt_tokens
self.caches[index] = deepcopy(cache)
self._snapshots[index] = merged or None
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
def _get_snapshot(
self, entry_index: int, target_token_count: int
) -> tuple[int, CacheSnapshot | None]:
if not has_non_kv_caches(self.caches[entry_index]):
return target_token_count, None
snapshots = self._snapshots[entry_index]
if not snapshots:
return 0, None
snap = _find_nearest_snapshot(snapshots, target_token_count)
if snap is not None:
return snap.token_count, snap
return 0, None
def get_kv_cache(
self,
model: Model,
prompt: str,
prompt_tokens: mx.array,
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
@@ -79,76 +145,71 @@ class KVPrefixCache:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
nearest SSM snapshot position at or before the match point for correctness. The same
applies to RotatingKVCache.
"""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
max_length = len(tokenized_prompt)
max_length = len(prompt_tokens)
best_snapshot_index, best_snapshot_length = None, 0
best_index: int | None = None
best_length = 0
is_exact = False
# Find best cache
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(prompt_tokens, cached_prompt)
if length > best_length:
best_index, best_length = i, length
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
is_exact = True
best_index, best_length = i, length
break
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_index is None:
return make_kv_cache(model), prompt_tokens, None
if best_snapshot_index is not None:
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
# This ensures stream_generate always has at least one token to start with
target = (max_length - 1) if is_exact else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# No usable snapshot — need fresh cache
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
return make_kv_cache(model), prompt_tokens, None
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
prompt_cache = deepcopy(self.caches[best_index])
cached_length = cache_length(self.caches[best_index])
tokens_to_trim = cached_length - restore_pos
if tokens_to_trim > 0:
trim_cache(prompt_cache, tokens_to_trim, restore_snap)
# Reset cache offset to match trimmed length
for c in prompt_cache:
if hasattr(c, "offset"):
c.offset = restore_pos
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
self._access_counter += 1
self._last_used[best_index] = self._access_counter
remaining = prompt_tokens[restore_pos:]
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
return prompt_cache, remaining, best_index
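The trim arithmetic in the SSM match path, with hypothetical numbers: a 100-token cached entry, a 90-token prompt matching 80 tokens, and a snapshot at 64 gives restore_pos = 64, so 36 tokens are trimmed off the end and prompt_tokens[64:] are re-prefilled.

```python
# Hypothetical sizes: entry caches 100 tokens, prompt is 90, snapshot at 64
cached_length, restore_pos, prompt_len = 100, 64, 90
tokens_to_trim = cached_length - restore_pos  # 36: trim back to the snapshot
to_prefill = prompt_len - restore_pos         # 26: prompt_tokens[restore_pos:]
assert (tokens_to_trim, to_prefill) == (36, 26)
```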
def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
# Evict LRU entries until below threshold or only one entry left
# Evict LRU entries until below threshold
while (
len(self.caches) > 1
len(self.caches) > 0
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._snapshots.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
@@ -169,6 +230,21 @@ class KVPrefixCache:
return max_pressure
def trim_cache(
cache: KVCacheType,
num_tokens: int,
snapshot: CacheSnapshot | None = None,
) -> None:
for i, c in enumerate(cache):
if isinstance(c, (ArraysCache, RotatingKVCache)):
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
@@ -177,14 +253,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(prompt_tokens)
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
return max(getattr(c, "offset", 0) for c in cache)
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
@@ -215,7 +291,7 @@ def make_kv_cache(
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
if hasattr(model, "make_cache"):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
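With the GptOssModel restriction lifted, any model exposing make_cache now gets its native cache layout; for hybrid SSM models that is what produces the ArraysCache entries the snapshot machinery above exists to handle.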


@@ -1,9 +1,10 @@
import time
from typing import Any, Callable, Generator, cast, get_args
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -23,7 +24,14 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
encode_prompt,
has_non_kv_caches,
make_kv_cache,
snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
DEFAULT_TOP_LOGPROBS,
KV_BITS,
@@ -32,6 +40,8 @@ from exo.worker.engines.mlx.constants import (
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
fix_unmatched_think_end_tokens,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -46,7 +56,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> tuple[float, int]:
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -57,17 +67,21 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0, 0
return 0.0, 0, []
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
has_ssm = has_non_kv_caches(cache)
snapshots: list[CacheSnapshot] = []
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
@@ -84,7 +98,18 @@ def prefill(
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
# stream_generate added 1 extra generated token to the cache, so we trim it.
# Because the arrays cache can only be rolled back to a snapshot, we generate on 2 tokens and trim 1 more.
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
for i, c in enumerate(cache):
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
assert pre_gen is not None
if pre_gen.states[i] is not None:
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -92,12 +117,14 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec, num_tokens
# Exclude the last snapshot
return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []
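The bookkeeping behind the trim-by-2, restated with a hypothetical 10-token prefill: stream_generate leaves prompt + 1 generated token in the cache, and because SSM states are restored from the second-to-last snapshot, the rollback lands one prompt token early, which generation re-consumes via a 2-token start.

```python
prefilled = 10              # tokens fed to stream_generate (hypothetical)
in_cache = prefilled + 1    # stream_generate appends one generated token
after_trim = in_cache - 2   # drop the generated token plus one prompt token
assert after_trim == prefilled - 1  # matches cache_length == len(tokens) - 1
```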
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -116,7 +143,7 @@ def warmup_inference(
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
sampler = make_sampler(temp=0.0)
logger.info("Generating warmup tokens")
for _r in stream_generate(
@@ -135,6 +162,8 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
mx_barrier(group)
return tokens_generated
@@ -216,11 +245,17 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
if task.seed is not None:
mx.random.seed(task.seed)
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
seed = task.seed or 42
mx.random.seed(seed)
# Encode prompt once at the top and fix unmatched think tags
all_prompt_tokens = encode_prompt(tokenizer, prompt)
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
# Do not use the prefix cache if we are trying to do benchmarks.
is_bench = task.bench
@@ -232,13 +267,16 @@ def mlx_generate(
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
prompt_tokens = all_prompt_tokens
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, all_prompt_tokens
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
if prefix_hit_length > 0:
logger.info(
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -261,12 +299,17 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
# stream_generate starts from the last two tokens (prefill trimmed one extra token)
last_token = prompt_tokens[-1:]
last_token = prompt_tokens[-2:]
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
@@ -294,7 +337,6 @@ def mlx_generate(
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)
accumulated_text += out.text
if think_start is not None and out.text == think_start:
@@ -362,16 +404,6 @@ def mlx_generate(
selected_token=out.token,
)
yield GenerationResponse(
text=text,
token=out.token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
usage=usage,
)
if is_done:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
@@ -385,14 +417,42 @@ def mlx_generate(
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
generated_tokens_array = mx.array(
tokenizer.encode(
"".join(generated_text_parts), add_special_tokens=False
)
)
full_prompt_tokens = mx.concatenate(
[all_prompt_tokens, generated_tokens_array]
)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
kv_prefix_cache.update_kv_cache(
matched_index,
full_prompt_tokens,
caches,
cache_snapshots,
restore_pos=prefix_hit_length,
)
else:
kv_prefix_cache.add_kv_cache(full_prompt, caches)
kv_prefix_cache.add_kv_cache(
full_prompt_tokens, caches, cache_snapshots
)
yield GenerationResponse(
text=text,
token=out.token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
usage=usage,
)
if is_done:
mx_barrier(group)
break
# Limit accumulated_text to what's needed for stop sequence detection


@@ -67,6 +67,8 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -84,6 +86,30 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None) -> int:
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
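A quick usage sketch for broadcast_from_zero, assuming a distributed backend is up (hypothetical seed value): non-zero ranks contribute 0 to the all_sum, so rank 0's value lands everywhere.

```python
group = mx.distributed.init()             # pass None to run single-process
seed = broadcast_from_zero(12345, group)  # inputs from ranks != 0 are ignored
mx.random.seed(seed)                      # every rank now seeds identically
```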
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -464,6 +490,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
return think_token is not None and prompt.rstrip().endswith(think_token)
def fix_unmatched_think_end_tokens(
tokens: mx.array, tokenizer: TokenizerWrapper
) -> mx.array:
if not tokenizer.has_thinking:
return tokens
assert tokenizer.think_start_id
assert tokenizer.think_end_id
think_start_id: int = tokenizer.think_start_id
think_end_id: int = tokenizer.think_end_id
token_list: list[int] = cast(list[int], tokens.tolist())
result: list[int] = []
depth = 0
for token in token_list:
if token == think_start_id:
depth += 1
elif token == think_end_id:
if depth == 0:
result.append(think_start_id)
else:
depth -= 1
result.append(token)
return mx.array(result)
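A worked example of the repair, re-implemented over plain ints so it runs standalone (hypothetical ids: 1 = think-start, 2 = think-end):

```python
def balance(tokens: list[int], start: int = 1, end: int = 2) -> list[int]:
    # Same repair as fix_unmatched_think_end_tokens: give every unmatched
    # end tag a matching start tag immediately before it.
    out: list[int] = []
    depth = 0
    for t in tokens:
        if t == start:
            depth += 1
        elif t == end:
            if depth == 0:
                out.append(start)  # insert the missing opener
            else:
                depth -= 1
        out.append(t)
    return out

assert balance([5, 2, 6]) == [5, 1, 2, 6]           # unmatched end gains an opener
assert balance([5, 1, 7, 2, 6]) == [5, 1, 7, 2, 6]  # balanced pair untouched
```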
class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.
@@ -536,23 +586,3 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)


@@ -32,7 +32,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -99,22 +98,23 @@ class Worker:
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
async with create_task_group() as tg:
try:
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -218,22 +218,15 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await runner.start_task(task)
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -271,18 +264,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self._start_runner_task(modified_task)
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self._start_runner_task(task)
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
@@ -321,6 +314,8 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(


@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -54,14 +53,13 @@ def plan(
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -272,7 +270,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -286,7 +284,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len(input_chunk_buffer.get(cmd_id, {}))
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -294,33 +292,16 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
# I have a design point here: this is a state race in disguise, since the task status doesn't get updated to completed fast enough
# realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner_id, runner in runners.items():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id,
cancelled_task_id=task.task_id,
runner_id=runner_id,
)


@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,7 +15,6 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -39,7 +38,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:


@@ -1,6 +1,5 @@
import base64
import json
import math
import time
from collections.abc import Generator
from functools import cache
@@ -88,7 +87,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -113,7 +111,6 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -128,15 +125,11 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -149,7 +142,6 @@ def main(
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -195,19 +187,19 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
kv_prefix_cache = KVPrefixCache(group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -215,6 +207,8 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -226,30 +220,16 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert not isinstance(model, DistributedImageModel)
assert tokenizer
t = time.perf_counter()
toks = warmup_inference(
model=inference_model,
model=model,
tokenizer=tokenizer,
group=group,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / (time.perf_counter() - t)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
@@ -257,8 +237,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -278,9 +258,9 @@ def main(
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert inference_model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert check_for_cancel_every
try:
_check_for_debug_prompts(task_params)
@@ -290,11 +270,12 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
group=group,
)
# For other thinking models (GLM, etc.), check if we need to
@@ -314,11 +295,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -331,18 +312,7 @@ def main(
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
completion_tokens += 1
@@ -414,7 +384,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -427,9 +397,7 @@ def main(
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
@@ -479,7 +447,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -492,9 +460,7 @@ def main(
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
@@ -560,7 +526,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc
@@ -663,7 +629,7 @@ def parse_thinking_models(
yield response.model_copy(
update={
"text": tokenizer.think_start,
"token": tokenizer.think_start_id, # type: ignore
"token": tokenizer.think_start_id,
}
)
yield response


@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
@@ -49,12 +47,9 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -65,8 +60,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A Task doubles as a command to the runner
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -74,7 +69,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +83,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,42 +90,35 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with self._tg as tg:
tg.start_soon(self._forward_events)
await self._forward_events()
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._tg.cancel_scope.cancel()
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def start_task(self, task: Task):
if task.task_id in self.pending:
@@ -155,13 +141,6 @@ class RunnerSupervisor:
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
async def _forward_events(self):
with self._ev_recv as events:
try:
@@ -226,4 +205,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()


@@ -88,12 +88,12 @@ class TestKVPrefix:
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache = KVPrefixCache()
cache.clear()
assert len(cache.prompts) == 0
@@ -142,10 +142,12 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert cache_length(cache) == len(tokens)
# Cache should now hold the prompt tokens minus one
assert cache_length(cache) == len(tokens) - 1
# Snapshots should be available for models with non-KV caches
assert len(snapshots) > 0
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -159,10 +161,10 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -170,7 +172,7 @@ class TestKVPrefixCacheWithModel:
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, tokens
)
assert matched_index == 0
@@ -191,10 +193,12 @@ class TestKVPrefixCacheWithModel:
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache
)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(short_prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
# Query with longer prompt that shares the chat template prefix
long_task = TextGenerationTaskParams(
@@ -212,13 +216,12 @@ class TestKVPrefixCacheWithModel:
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, long_prompt
model, long_tokens
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
# remaining_tokens covers from snapshot restore position to end
assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
@@ -235,15 +238,15 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
@@ -273,15 +276,15 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
@@ -298,7 +301,7 @@ class TestKVPrefixCacheWithModel:
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
@@ -328,7 +331,7 @@ class TestKVPrefixCacheWithModel:
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
@@ -352,20 +355,20 @@ class TestKVPrefixCacheWithModel:
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
model, prompt_tokens
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
# Exact match: remaining_tokens is the last two tokens
assert len(remaining_tokens) == 2
assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -444,7 +447,7 @@ class TestKVPrefixCacheWithModel:
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
@@ -481,7 +484,7 @@ class TestKVPrefixCacheWithModel:
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache()
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
@@ -495,7 +498,7 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -505,19 +508,10 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
# Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
with patch(
"exo.worker.engines.mlx.cache.get_memory_used_percentage",
return_value=0.95,
):
# Trigger eviction by adding a new entry
task = TextGenerationTaskParams(
@@ -529,14 +523,11 @@ class TestKVPrefixCacheWithModel:
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since the mocked memory usage stays above threshold after each eviction,
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)
assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)


@@ -34,6 +34,7 @@ TOKENIZER_FILE_PATTERNS = [
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
"tool_declaration_ts.py", # Dependency of tokenization_kimi.py
]


@@ -2,7 +2,6 @@
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -20,7 +19,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -115,8 +113,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -167,7 +163,6 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
event_sender = EventCollector()
with task_sender:
@@ -179,7 +174,7 @@ def _run(tasks: Iterable[Task]):
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
return event_sender.events

tests/auto_bench.sh (new executable file, 53 lines)

@@ -0,0 +1,53 @@
#!/usr/bin/env bash
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
sleep 1
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done
for host; do
echo "Waiting for $host..." 1>&2
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
bench_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$bench_runner" | while IFS= read -r model; do
echo "running bench for $model" 1>&2
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$bench_runner@$bench_runner" "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring" >>"./bench/$commit/${model//\//--}.json"
echo
done
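Each model's results land in ./bench/<commit>/<model>.json, with / in the HF-style model id rewritten to -- by the ${model//\//--} expansion, so an id like mlx-community/Some-Model (hypothetical) becomes mlx-community--Some-Model.json.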


@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# pyright: reportAny=false
import json
import subprocess
import sys
from typing import Any, cast
from urllib.request import urlopen
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = next(
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
) or sys.exit(f"{h} not found in tailscale")
with urlopen(f"http://{ip}:52415/state", timeout=5) as r:
data = json.loads(r.read()).get("downloads", {})
def mid(x: dict[str, Any]) -> str | None:
for k in (
"DownloadCompleted",
"shardMetadata",
"PipelineShardMetadata",
"modelCard",
"modelId",
):
x = x.get(k, {})
return cast(str | None, x if x != {} else None)
common = set[str].intersection(
*[{m for d in nid if (m := mid(d))} for nid in data.values()]
)
for c in common:
print(c)
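This is the script behind the exo-get-all-models-on-cluster flake app that auto_bench.sh invokes: given a tailscale hostname, it resolves the node's IP, reads /state from the exo API on port 52415, and prints the model ids whose downloads have completed on every node - exactly the set the bench loop iterates over.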


@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait

uv.lock (generated, 2173 lines)

File diff suppressed because it is too large