mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-06 12:11:22 -05:00
Compare commits
9 Commits
slow-down-
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2f47802c8 | ||
|
|
a067fd0753 | ||
|
|
333119ec2f | ||
|
|
79661ea3a3 | ||
|
|
196f4fb35f | ||
|
|
08a4ca59c6 | ||
|
|
45ed3903f3 | ||
|
|
6657d8600f | ||
|
|
bf06580e0e |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -32,6 +32,3 @@ dashboard/.svelte-kit/
|
||||
# host config snapshots
|
||||
hosts_*.json
|
||||
.swp
|
||||
|
||||
# bench files
|
||||
bench/**/*.json
|
||||
|
||||
@@ -1139,7 +1139,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`flatten`."""
|
||||
|
||||
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
|
||||
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`reshape` but the shape can be passed either as a
|
||||
:obj:`tuple` or as separate arguments.
|
||||
@@ -1222,7 +1222,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`swapaxes`."""
|
||||
|
||||
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
|
||||
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`transpose` but the axes can be passed either as
|
||||
a tuple or as separate arguments.
|
||||
|
||||
@@ -30,9 +30,6 @@ class Conv1d(Module):
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
|
||||
weight: mx.array
|
||||
groups: int
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
||||
@@ -11,10 +11,7 @@ import mlx.core as mx
|
||||
class Cache(Protocol):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
offset: int
|
||||
def update_and_fetch(
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
@@ -90,7 +87,6 @@ def create_attention_mask(
|
||||
class _BaseCache(Cache):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.mla import MultiLinear
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
@@ -61,10 +60,7 @@ class DeepseekV3Attention(nn.Module):
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
# kv_b_proj: nn.Linear
|
||||
embed_q: MultiLinear
|
||||
unembed_out: MultiLinear
|
||||
|
||||
kv_b_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""Type stubs for mlx_lm.models.qwen3_next"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
class Qwen3NextMLP(nn.Module):
|
||||
gate_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextGatedDeltaNet(nn.Module):
|
||||
hidden_size: int
|
||||
num_v_heads: int
|
||||
num_k_heads: int
|
||||
head_k_dim: int
|
||||
head_v_dim: int
|
||||
key_dim: int
|
||||
value_dim: int
|
||||
conv_kernel_size: int
|
||||
conv_dim: int
|
||||
conv1d: nn.Conv1d
|
||||
in_proj_qkvz: nn.Linear
|
||||
in_proj_ba: nn.Linear
|
||||
dt_bias: mx.array
|
||||
A_log: mx.array
|
||||
out_proj: nn.Linear
|
||||
|
||||
def __init__(self, config: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
norm_topk_prob: bool
|
||||
num_experts: int
|
||||
top_k: int
|
||||
gate: nn.Linear
|
||||
switch_mlp: SwitchGLU
|
||||
shared_expert: Qwen3NextMLP
|
||||
shared_expert_gate: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextDecoderLayer(nn.Module):
|
||||
is_linear: bool
|
||||
linear_attn: Qwen3NextGatedDeltaNet
|
||||
self_attn: Qwen3NextAttention
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
|
||||
|
||||
def __init__(self, args: Any, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextModel(nn.Module):
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Qwen3NextDecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: Qwen3NextModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[Qwen3NextDecoderLayer]: ...
|
||||
@@ -113,10 +113,6 @@ class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
think_start: str | None
|
||||
think_end: str | None
|
||||
think_start_id: int | None
|
||||
think_end_id: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -5,21 +5,21 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[X] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
|
||||
@@ -431,12 +431,7 @@ def main() -> int:
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
@@ -455,7 +450,6 @@ def main() -> int:
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
@@ -539,16 +533,6 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
@@ -668,9 +652,7 @@ def main() -> int:
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.stdout:
|
||||
json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
|
||||
elif args.json_out:
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
isLoading,
|
||||
stopGeneration,
|
||||
sendMessage,
|
||||
generateImage,
|
||||
editImage,
|
||||
@@ -605,86 +606,92 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => stopGeneration()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
<rect x="4" y="4" width="16" height="16" rx="2" />
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
<span class="hidden sm:inline">STOP</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -470,6 +470,7 @@ class AppStore {
|
||||
messages = $state<Message[]>([]);
|
||||
currentResponse = $state("");
|
||||
isLoading = $state(false);
|
||||
private currentAbortController: AbortController | null = null;
|
||||
|
||||
// Performance metrics
|
||||
ttftMs = $state<number | null>(null); // Time to first token in ms
|
||||
@@ -1738,9 +1739,11 @@ class AppStore {
|
||||
return;
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -1854,6 +1857,7 @@ class AppStore {
|
||||
"Unknown error",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -1985,6 +1989,10 @@ class AppStore {
|
||||
assistantMessageId: string,
|
||||
errorPrefix = "Failed to get response",
|
||||
): void {
|
||||
// Don't show error for user-initiated abort (stop button)
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
@@ -2026,6 +2034,17 @@ class AppStore {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the current generation by aborting the HTTP connection.
|
||||
* This triggers backend cancellation via the mechanism in PR #1276.
|
||||
*/
|
||||
stopGeneration() {
|
||||
if (this.currentAbortController) {
|
||||
this.currentAbortController.abort();
|
||||
this.currentAbortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to the LLM and stream the response
|
||||
*/
|
||||
@@ -2173,11 +2192,13 @@ class AppStore {
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = 0;
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -2325,6 +2346,7 @@ class AppStore {
|
||||
"Failed to get response",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -2424,11 +2446,13 @@ class AppStore {
|
||||
};
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/images/generations", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
@@ -2577,6 +2601,7 @@ class AppStore {
|
||||
"Failed to generate image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2705,8 +2730,10 @@ class AppStore {
|
||||
);
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const apiResponse = await fetch("/v1/images/edits", {
|
||||
method: "POST",
|
||||
signal: this.currentAbortController.signal,
|
||||
body: formData,
|
||||
});
|
||||
|
||||
@@ -2816,6 +2843,7 @@ class AppStore {
|
||||
"Failed to edit image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2944,6 +2972,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
|
||||
export const messages = () => appStore.messages;
|
||||
export const currentResponse = () => appStore.currentResponse;
|
||||
export const isLoading = () => appStore.isLoading;
|
||||
export const stopGeneration = () => appStore.stopGeneration();
|
||||
export const ttftMs = () => appStore.ttftMs;
|
||||
export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
|
||||
@@ -118,10 +118,9 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
2
justfile
2
justfile
@@ -20,7 +20,7 @@ sync-clean:
|
||||
|
||||
rust-rebuild:
|
||||
cargo run --bin stub_gen
|
||||
uv sync --reinstall-package exo_pyo3_bindings
|
||||
just sync-clean
|
||||
|
||||
build-dashboard:
|
||||
#!/usr/bin/env bash
|
||||
|
||||
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.5"; in
|
||||
version = let v = "0.30.4"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.5; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.5; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -31,6 +31,8 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
|
||||
# dependencies only required for development
|
||||
@@ -61,7 +63,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
@@ -103,7 +105,6 @@ root = "src"
|
||||
|
||||
# supported platforms for this project
|
||||
[tool.uv]
|
||||
required-version = ">=0.8.6"
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
|
||||
@@ -59,22 +59,6 @@
|
||||
}
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ exoVenv ];
|
||||
runtimeEnv = {
|
||||
EXO_DASHBOARD_DIR = self'.packages.dashboard;
|
||||
EXO_RESOURCES_DIR = inputs.self + /resources;
|
||||
};
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ pkgs.python313 ];
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
@@ -82,30 +66,28 @@
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper script
|
||||
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
|
||||
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
|
||||
{
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
} // {
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -16,7 +16,6 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -54,10 +53,11 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
@@ -108,13 +108,6 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -158,78 +158,6 @@ async def seed_models(seed_dir: str | Path):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def _build_file_list_from_local_directory(
|
||||
model_id: ModelId,
|
||||
recursive: bool = False,
|
||||
) -> list[FileListEntry] | None:
|
||||
"""Build a file list from locally existing model files.
|
||||
|
||||
We can only figure out the files we need from safetensors index, so
|
||||
a local directory must contain a *.safetensors.index.json and
|
||||
safetensors listed there.
|
||||
"""
|
||||
model_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return None
|
||||
|
||||
def _scan() -> list[FileListEntry] | None:
|
||||
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
|
||||
if not index_files:
|
||||
return None
|
||||
|
||||
entries_by_path: dict[str, FileListEntry] = {}
|
||||
|
||||
if recursive:
|
||||
for dirpath, _, filenames in os.walk(model_dir):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".partial"):
|
||||
continue
|
||||
full_path = Path(dirpath) / filename
|
||||
rel_path = str(full_path.relative_to(model_dir))
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=full_path.stat().st_size,
|
||||
)
|
||||
else:
|
||||
for item in model_dir.iterdir():
|
||||
if item.is_file() and not item.name.endswith(".partial"):
|
||||
entries_by_path[item.name] = FileListEntry(
|
||||
type="file",
|
||||
path=item.name,
|
||||
size=item.stat().st_size,
|
||||
)
|
||||
|
||||
# Add expected weight files from index that haven't been downloaded yet
|
||||
for index_file in index_files:
|
||||
try:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(
|
||||
index_file.read_text()
|
||||
)
|
||||
relative_dir = index_file.parent.relative_to(model_dir)
|
||||
for filename in set(index_data.weight_map.values()):
|
||||
rel_path = (
|
||||
str(relative_dir / filename)
|
||||
if relative_dir != Path(".")
|
||||
else filename
|
||||
)
|
||||
if rel_path not in entries_by_path:
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=None,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return list(entries_by_path.values())
|
||||
|
||||
file_list = await asyncio.to_thread(_scan)
|
||||
if not file_list:
|
||||
return None
|
||||
return file_list
|
||||
|
||||
|
||||
_fetched_file_lists_this_session: set[str] = set()
|
||||
|
||||
|
||||
@@ -255,14 +183,6 @@ async def fetch_file_list_with_cache(
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(
|
||||
f"No internet connection and no cached file list for {model_id}"
|
||||
)
|
||||
@@ -283,18 +203,10 @@ async def fetch_file_list_with_cache(
|
||||
except Exception as e:
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id} and no cache exists, "
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
|
||||
|
||||
|
||||
@@ -466,14 +378,10 @@ async def download_file_with_retry(
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
on_connection_lost()
|
||||
raise e
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
break
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
@@ -195,10 +195,6 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self, shard: ShardMetadata
|
||||
) -> RepoDownloadProgress:
|
||||
_, progress = await download_shard(
|
||||
shard,
|
||||
self.on_progress_wrapper,
|
||||
skip_download=True,
|
||||
skip_internet=not self.internet_connection,
|
||||
on_connection_lost=lambda: self.set_internet_connection(False),
|
||||
shard, self.on_progress_wrapper, skip_download=True
|
||||
)
|
||||
return progress
|
||||
|
||||
@@ -27,6 +27,7 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -105,7 +106,6 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -136,6 +136,7 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -147,8 +148,6 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -189,9 +188,6 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -176,7 +176,7 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ChatCompletionResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -223,7 +223,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
yield ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -241,4 +241,5 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
@@ -123,6 +123,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -529,16 +530,14 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
@@ -633,11 +632,14 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
@@ -653,8 +655,7 @@ class API:
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
@@ -856,6 +857,11 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -937,6 +943,11 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -1320,40 +1331,29 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
with anyio.CancelScope(shield=True):
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=ev.wait,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -6,7 +6,6 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -17,12 +16,12 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
@@ -38,6 +37,7 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
@@ -68,9 +68,12 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -80,7 +83,6 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -96,18 +98,16 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -278,16 +278,8 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
@@ -298,7 +290,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -308,7 +300,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -318,6 +310,18 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -326,17 +330,12 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
for i in range(
|
||||
command.since_idx,
|
||||
min(command.since_idx + 1000, len(self._event_log)),
|
||||
):
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
await self._send_event(
|
||||
IndexedEvent(idx=i, event=self._event_log[i])
|
||||
)
|
||||
|
||||
@@ -15,20 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -186,6 +186,7 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -201,6 +202,18 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
@@ -208,29 +221,3 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
@@ -11,7 +11,6 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -48,7 +47,6 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -69,7 +67,6 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
|
||||
target_instances = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 0
|
||||
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
|
||||
@@ -9,7 +9,6 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -147,21 +146,18 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -170,12 +166,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,29 +86,28 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -72,12 +76,7 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
|
||||
|
||||
Command = (
|
||||
@@ -89,6 +88,7 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
|
||||
|
||||
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -87,6 +93,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -194,10 +194,9 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -13,9 +13,6 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
@@ -28,21 +25,16 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import MiniMaxAttention
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
@@ -511,24 +503,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
# layer.self_attn.kv_b_proj
|
||||
# )
|
||||
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
@@ -644,84 +624,6 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: "Cache | None" = None,
|
||||
) -> mx.array:
|
||||
batch_dim, seq_dim, _ = x.shape
|
||||
|
||||
self._original_layer = cast(MiniMaxAttention, self.original_layer) # type: ignore
|
||||
|
||||
queries: mx.array = self._original_layer.q_proj(x)
|
||||
keys: mx.array = self._original_layer.k_proj(x)
|
||||
values: mx.array = self._original_layer.v_proj(x)
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1]
|
||||
k_dim = keys.shape[-1]
|
||||
n = self.group.size()
|
||||
|
||||
qk = mx.concatenate(
|
||||
[queries, keys], axis=-1
|
||||
) # (batch_dim, seq_dim, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (n*batch_dim, seq_dim, q_dim + k_dim)
|
||||
|
||||
qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * q_dim)
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
batch_dim, seq_dim, -1
|
||||
) # (batch_dim, seq_dim, n * k_dim)
|
||||
|
||||
queries = self._original_layer.q_norm(queries)
|
||||
keys = self._original_layer.k_norm(keys)
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, n, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, n, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(
|
||||
batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self._original_layer.rope(queries, offset=cache.offset)
|
||||
keys = self._original_layer.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self._original_layer.rope(queries)
|
||||
keys = self._original_layer.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self._original_layer.scale, # type: ignore
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
|
||||
|
||||
return self._original_layer.o_proj(output)
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -730,6 +632,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -740,11 +643,18 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -769,95 +679,18 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
if isinstance(layer, Qwen3DecoderLayer):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
linear_attn = layer.linear_attn
|
||||
|
||||
linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_qkvz
|
||||
)
|
||||
linear_attn.in_proj_ba = self.all_to_sharded_linear(
|
||||
linear_attn.in_proj_ba
|
||||
)
|
||||
linear_attn.out_proj = self.sharded_to_all_linear(
|
||||
linear_attn.out_proj
|
||||
)
|
||||
|
||||
# Shard conv1d: depthwise conv with non-contiguous channel slicing.
|
||||
# Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
|
||||
# Each rank takes its head-slice from each of the three sections.
|
||||
rank = self.group.rank()
|
||||
key_dim = linear_attn.key_dim
|
||||
value_dim = linear_attn.value_dim
|
||||
key_dim_shard = key_dim // self.N
|
||||
value_dim_shard = value_dim // self.N
|
||||
|
||||
q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
|
||||
k_idx = mx.arange(
|
||||
key_dim + rank * key_dim_shard,
|
||||
key_dim + (rank + 1) * key_dim_shard,
|
||||
)
|
||||
v_idx = mx.arange(
|
||||
2 * key_dim + rank * value_dim_shard,
|
||||
2 * key_dim + (rank + 1) * value_dim_shard,
|
||||
)
|
||||
conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
|
||||
linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
|
||||
new_conv_dim = key_dim_shard * 2 + value_dim_shard
|
||||
linear_attn.conv1d.groups = new_conv_dim
|
||||
|
||||
num_v_shard = linear_attn.num_v_heads // self.N
|
||||
v_start = rank * num_v_shard
|
||||
v_end = v_start + num_v_shard
|
||||
linear_attn.A_log = linear_attn.A_log[v_start:v_end]
|
||||
linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]
|
||||
|
||||
linear_attn.num_k_heads //= self.N
|
||||
linear_attn.num_v_heads //= self.N
|
||||
linear_attn.key_dim = (
|
||||
linear_attn.head_k_dim * linear_attn.num_k_heads
|
||||
)
|
||||
linear_attn.value_dim = (
|
||||
linear_attn.head_v_dim * linear_attn.num_v_heads
|
||||
)
|
||||
linear_attn.conv_dim = (
|
||||
linear_attn.key_dim * 2 + linear_attn.value_dim
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
@@ -867,14 +700,6 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_expert.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_expert.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -24,119 +26,51 @@ _MEMORY_THRESHOLD = float(
|
||||
)
|
||||
|
||||
|
||||
class CacheSnapshot:
|
||||
"""Snapshot of states at a known token position."""
|
||||
|
||||
def __init__(
|
||||
self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
|
||||
):
|
||||
self.states = states
|
||||
self.token_count = token_count
|
||||
|
||||
|
||||
def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
|
||||
states: list[ArraysCache | RotatingKVCache | None] = []
|
||||
for c in cache:
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
states.append(deepcopy(c))
|
||||
else:
|
||||
states.append(None)
|
||||
token_count = cache_length(cache)
|
||||
return CacheSnapshot(states=states, token_count=token_count)
|
||||
|
||||
|
||||
def _find_nearest_snapshot(
|
||||
snapshots: list[CacheSnapshot],
|
||||
target_token_count: int,
|
||||
) -> CacheSnapshot | None:
|
||||
best: CacheSnapshot | None = None
|
||||
for snap in snapshots:
|
||||
if snap.token_count <= target_token_count and (
|
||||
best is None or snap.token_count > best.token_count
|
||||
):
|
||||
best = snap
|
||||
return best
|
||||
|
||||
|
||||
def has_non_kv_caches(cache: KVCacheType) -> bool:
|
||||
"""Check if a cache contains any ArraysCache (SSM) entries."""
|
||||
return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, group: mx.distributed.Group | None = None):
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._snapshots: list[list[CacheSnapshot] | None] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
ssm_snapshots: list[CacheSnapshot] | None = None,
|
||||
):
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
self.prompts.append(prompt_tokens)
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._snapshots.append(ssm_snapshots)
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt_tokens: mx.array,
|
||||
prompt: str,
|
||||
cache: KVCacheType,
|
||||
snapshots: list[CacheSnapshot] | None,
|
||||
restore_pos: int,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
old_snapshots = self._snapshots[index]
|
||||
merged: list[CacheSnapshot] = []
|
||||
if old_snapshots:
|
||||
merged = [s for s in old_snapshots if s.token_count <= restore_pos]
|
||||
if snapshots:
|
||||
merged.extend(snapshots)
|
||||
|
||||
self.prompts[index] = prompt_tokens
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._snapshots[index] = merged or None
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
|
||||
|
||||
def _get_snapshot(
|
||||
self, entry_index: int, target_token_count: int
|
||||
) -> tuple[int, CacheSnapshot | None]:
|
||||
if not has_non_kv_caches(self.caches[entry_index]):
|
||||
return target_token_count, None
|
||||
|
||||
snapshots = self._snapshots[entry_index]
|
||||
if not snapshots:
|
||||
return 0, None
|
||||
|
||||
snap = _find_nearest_snapshot(snapshots, target_token_count)
|
||||
if snap is not None:
|
||||
return snap.token_count, snap
|
||||
|
||||
return 0, None
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
prompt_tokens: mx.array,
|
||||
prompt: str,
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
@@ -145,71 +79,76 @@ class KVPrefixCache:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
|
||||
For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
|
||||
nearest SSM snapshot position at or before the match point for correctness.
|
||||
Same for rotating KV Cache.
|
||||
"""
|
||||
max_length = len(prompt_tokens)
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
|
||||
best_index: int | None = None
|
||||
best_length = 0
|
||||
is_exact = False
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
|
||||
# Find best cache
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(prompt_tokens, cached_prompt)
|
||||
if length > best_length:
|
||||
best_index, best_length = i, length
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
is_exact = True
|
||||
best_index, best_length = i, length
|
||||
break
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
|
||||
if best_index is None:
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
|
||||
# For exact match: trim to max_length-1 so remaining has the last token
|
||||
# For partial match: trim to best_length, remaining has suffix to prefill
|
||||
# This ensures stream_generate always has at least one token to start with
|
||||
target = (max_length - 1) if is_exact else best_length
|
||||
restore_pos, restore_snap = self._get_snapshot(best_index, target)
|
||||
if best_snapshot_index is not None:
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
|
||||
# No usable snapshot — need fresh cache
|
||||
if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
|
||||
return make_kv_cache(model), prompt_tokens, None
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
prompt_cache = deepcopy(self.caches[best_index])
|
||||
cached_length = cache_length(self.caches[best_index])
|
||||
tokens_to_trim = cached_length - restore_pos
|
||||
if tokens_to_trim > 0:
|
||||
trim_cache(prompt_cache, tokens_to_trim, restore_snap)
|
||||
# Reset cache offset to match trimmed length
|
||||
for c in prompt_cache:
|
||||
if hasattr(c, "offset"):
|
||||
c.offset = restore_pos
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
|
||||
self._access_counter += 1
|
||||
self._last_used[best_index] = self._access_counter
|
||||
remaining = prompt_tokens[restore_pos:]
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
|
||||
return prompt_cache, remaining, best_index
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while (
|
||||
len(self.caches) > 0
|
||||
len(self.caches) > 1
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._snapshots.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
@@ -230,21 +169,6 @@ class KVPrefixCache:
|
||||
return max_pressure
|
||||
|
||||
|
||||
def trim_cache(
|
||||
cache: KVCacheType,
|
||||
num_tokens: int,
|
||||
snapshot: CacheSnapshot | None = None,
|
||||
) -> None:
|
||||
for i, c in enumerate(cache):
|
||||
if isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
if snapshot is not None and snapshot.states[i] is not None:
|
||||
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
|
||||
else:
|
||||
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
else:
|
||||
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
@@ -253,14 +177,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(prompt_tokens)
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
|
||||
return max(getattr(c, "offset", 0) for c in cache)
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
@@ -291,7 +215,7 @@ def make_kv_cache(
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache"):
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
@@ -24,14 +23,7 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
encode_prompt,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
DEFAULT_TOP_LOGPROBS,
|
||||
KV_BITS,
|
||||
@@ -40,8 +32,6 @@ from exo.worker.engines.mlx.constants import (
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -56,7 +46,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
) -> tuple[float, int]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
@@ -67,21 +57,17 @@ def prefill(
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0, 0, []
|
||||
return 0.0, 0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
has_ssm = has_non_kv_caches(cache)
|
||||
snapshots: list[CacheSnapshot] = []
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
elapsed = time.time() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
if has_ssm:
|
||||
snapshots.append(snapshot_ssm_states(cache))
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
@@ -98,18 +84,7 @@ def prefill(
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
# Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
|
||||
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
|
||||
for i, c in enumerate(cache):
|
||||
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
assert pre_gen is not None
|
||||
if pre_gen.states[i] is not None:
|
||||
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
@@ -117,14 +92,12 @@ def prefill(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
# Exclude the last snapshot
|
||||
return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []
|
||||
return tokens_per_sec, num_tokens
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -143,7 +116,7 @@ def warmup_inference(
|
||||
)
|
||||
|
||||
# Use a default sampler for warmup
|
||||
sampler = make_sampler(temp=0.0)
|
||||
sampler = make_sampler(temp=0.7)
|
||||
|
||||
logger.info("Generating warmup tokens")
|
||||
for _r in stream_generate(
|
||||
@@ -162,8 +135,6 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
@@ -245,17 +216,11 @@ def mlx_generate(
|
||||
task: TextGenerationTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
# TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
|
||||
seed = task.seed or 42
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Encode prompt once at the top and fix unmatched think tags
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
is_bench = task.bench
|
||||
@@ -267,16 +232,13 @@ def mlx_generate(
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = all_prompt_tokens
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, all_prompt_tokens
|
||||
model, prompt
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
if prefix_hit_length > 0:
|
||||
logger.info(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -299,17 +261,12 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
sampler,
|
||||
prompt_tokens[:-1],
|
||||
caches,
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-2:]
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
max_tokens = task.max_output_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
@@ -337,6 +294,7 @@ def mlx_generate(
|
||||
start=1,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
@@ -404,6 +362,16 @@ def mlx_generate(
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
@@ -417,42 +385,14 @@ def mlx_generate(
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
generated_tokens_array = mx.array(
|
||||
tokenizer.encode(
|
||||
"".join(generated_text_parts), add_special_tokens=False
|
||||
)
|
||||
)
|
||||
full_prompt_tokens = mx.concatenate(
|
||||
[all_prompt_tokens, generated_tokens_array]
|
||||
)
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(
|
||||
matched_index,
|
||||
full_prompt_tokens,
|
||||
caches,
|
||||
cache_snapshots,
|
||||
restore_pos=prefix_hit_length,
|
||||
)
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(
|
||||
full_prompt_tokens, caches, cache_snapshots
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if is_done:
|
||||
mx_barrier(group)
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
|
||||
@@ -67,8 +67,6 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -86,30 +84,6 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -490,30 +464,6 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
|
||||
return think_token is not None and prompt.rstrip().endswith(think_token)
|
||||
|
||||
|
||||
def fix_unmatched_think_end_tokens(
|
||||
tokens: mx.array, tokenizer: TokenizerWrapper
|
||||
) -> mx.array:
|
||||
if not tokenizer.has_thinking:
|
||||
return tokens
|
||||
assert tokenizer.think_start_id
|
||||
assert tokenizer.think_end_id
|
||||
think_start_id: int = tokenizer.think_start_id
|
||||
think_end_id: int = tokenizer.think_end_id
|
||||
token_list: list[int] = cast(list[int], tokens.tolist())
|
||||
result: list[int] = []
|
||||
depth = 0
|
||||
for token in token_list:
|
||||
if token == think_start_id:
|
||||
depth += 1
|
||||
elif token == think_end_id:
|
||||
if depth == 0:
|
||||
result.append(think_start_id)
|
||||
else:
|
||||
depth -= 1
|
||||
result.append(token)
|
||||
return mx.array(result)
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
"""
|
||||
A KVCache that pretends to exist but holds zero tokens.
|
||||
@@ -586,3 +536,23 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -98,23 +99,22 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with create_task_group() as tg:
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
tg.start_soon(runner.shutdown)
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -218,15 +218,22 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
await runner.start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -264,18 +271,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
await self._start_runner_task(modified_task)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
await self._start_runner_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -314,8 +321,6 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -53,13 +54,14 @@ def plan(
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
)
|
||||
|
||||
|
||||
@@ -270,7 +272,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -284,7 +286,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -292,16 +294,33 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner_id, runner in runners.items():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id,
|
||||
cancelled_task_id=task.task_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,6 +15,7 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -38,7 +39,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
@@ -87,6 +88,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -111,6 +113,7 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
@@ -125,11 +128,15 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -142,6 +149,7 @@ def main(
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -187,19 +195,19 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -207,8 +215,6 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -220,16 +226,30 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
t = time.perf_counter()
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / (time.perf_counter() - t)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
mx.max(
|
||||
mx.distributed.all_gather(
|
||||
mx.array([check_for_cancel_every]), group=group
|
||||
)
|
||||
).item()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||
)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
@@ -237,8 +257,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -258,9 +278,9 @@ def main(
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
assert check_for_cancel_every
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params)
|
||||
@@ -270,12 +290,11 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
@@ -295,11 +314,11 @@ def main(
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
elif isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
inference_model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -312,7 +331,18 @@ def main(
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
@@ -384,7 +414,7 @@ def main(
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -397,7 +427,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
@@ -447,7 +479,7 @@ def main(
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -460,7 +492,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -526,7 +560,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -629,7 +663,7 @@ def parse_thinking_models(
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id,
|
||||
"token": tokenizer.think_start_id, # type: ignore
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
@@ -8,8 +8,10 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
@@ -47,9 +49,12 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -60,8 +65,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -69,6 +74,7 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -83,6 +89,7 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -90,35 +97,42 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
await self._forward_events()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.close()
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
await to_thread.run_sync(self.runner_process.join, 10)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
@@ -141,6 +155,13 @@ class RunnerSupervisor:
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
@@ -205,4 +226,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
self.shutdown()
|
||||
await self.shutdown()
|
||||
|
||||
@@ -88,12 +88,12 @@ class TestKVPrefix:
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
@@ -101,7 +101,7 @@ class TestKVPrefix:
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache()
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
@@ -142,12 +142,10 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
assert cache_length(cache) == len(tokens) - 1
|
||||
# Snapshots should be available for models with non-KV caches
|
||||
assert len(snapshots) > 0
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -161,10 +159,10 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
@@ -172,7 +170,7 @@ class TestKVPrefixCacheWithModel:
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, tokens
|
||||
model, prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
@@ -193,12 +191,10 @@ class TestKVPrefixCacheWithModel:
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache
|
||||
)
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = TextGenerationTaskParams(
|
||||
@@ -216,12 +212,13 @@ class TestKVPrefixCacheWithModel:
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_tokens
|
||||
model, long_prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens covers from snapshot restore position to end
|
||||
assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
@@ -238,15 +235,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
@@ -276,15 +273,15 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
@@ -301,7 +298,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello")],
|
||||
@@ -331,7 +328,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Reuse test")],
|
||||
@@ -355,20 +352,20 @@ class TestKVPrefixCacheWithModel:
|
||||
# Second call should find a prefix match (the stored cache contains
|
||||
# prompt + generated tokens, which shares the prompt prefix)
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt_tokens
|
||||
model, prompt
|
||||
)
|
||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
||||
# so this is a prefix match where our prompt is fully contained
|
||||
assert matched_index == 0
|
||||
# Exact match: remaining_tokens is just the last token and the one before
|
||||
assert len(remaining_tokens) == 2
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])
|
||||
# Exact match: remaining_tokens is just the last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
||||
|
||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -447,7 +444,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Immutable test")],
|
||||
@@ -484,7 +481,7 @@ class TestKVPrefixCacheWithModel:
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache()
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
@@ -498,7 +495,7 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
|
||||
@@ -508,10 +505,19 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache._last_used[2] = 100.0
|
||||
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
|
||||
|
||||
# Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
|
||||
with patch(
|
||||
"exo.worker.engines.mlx.cache.get_memory_used_percentage",
|
||||
return_value=0.95,
|
||||
# Simulate memory pressure: active memory exceeds threshold
|
||||
fake_limit = 1000
|
||||
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
|
||||
return_value=fake_active,
|
||||
),
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.device_info",
|
||||
return_value={"max_recommended_working_set_size": fake_limit},
|
||||
),
|
||||
):
|
||||
# Trigger eviction by adding a new entry
|
||||
task = TextGenerationTaskParams(
|
||||
@@ -523,11 +529,14 @@ class TestKVPrefixCacheWithModel:
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
# Since fake_active stays above threshold after each eviction (we don't change it),
|
||||
# all old entries get evicted, leaving only the newly added one
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
|
||||
@@ -34,7 +34,6 @@ TOKENIZER_FILE_PATTERNS = [
|
||||
"added_tokens.json",
|
||||
"tokenizer.model",
|
||||
"tokenization_*.py", # Custom tokenizer implementations
|
||||
"tool_declaration_ts.py", # Dependency of tokenization_kimi.py
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -113,6 +115,8 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
||||
event_sender = EventCollector()
|
||||
|
||||
with task_sender:
|
||||
@@ -174,7 +179,7 @@ def _run(tasks: Iterable[Task]):
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ $# -lt 1 ] && {
|
||||
echo "Usage: $0 host1 [host2 ...]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
[ -z "$(git status --porcelain)" ] || {
|
||||
echo "Uncommitted changes"
|
||||
exit 1
|
||||
}
|
||||
|
||||
commit=$(git rev-parse HEAD)
|
||||
git fetch -q origin
|
||||
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
|
||||
echo "Not pushed to origin"
|
||||
exit 1
|
||||
}
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
sleep 1
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
}
|
||||
trap 'cleanup' EXIT INT TERM
|
||||
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
|
||||
done
|
||||
wait
|
||||
for host; do
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..." 1>&2
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
|
||||
echo "Waiting 30s for cluster setup" 1>&2
|
||||
sleep 30
|
||||
echo "EXO loaded" 1>&2
|
||||
bench_runner="${hosts[0]}"
|
||||
mkdir -p "./bench/$commit"
|
||||
nix run .#exo-get-all-models-on-cluster -- "$bench_runner" | while IFS= read -r model; do
|
||||
echo "running bench for $model" 1>&2
|
||||
ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$bench_runner@$bench_runner" "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring" >>"./bench/$commit/${model//\//--}.json"
|
||||
echo
|
||||
done
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
from urllib.request import urlopen
|
||||
|
||||
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
|
||||
ts = subprocess.run(
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = next(
|
||||
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
|
||||
) or sys.exit(f"{h} not found in tailscale")
|
||||
with urlopen(f"http://{ip}:52415/state", timeout=5) as r:
|
||||
data = json.loads(r.read()).get("downloads", {})
|
||||
|
||||
|
||||
def mid(x: dict[str, Any]) -> str | None:
|
||||
for k in (
|
||||
"DownloadCompleted",
|
||||
"shardMetadata",
|
||||
"PipelineShardMetadata",
|
||||
"modelCard",
|
||||
"modelId",
|
||||
):
|
||||
x = x.get(k, {})
|
||||
return cast(str | None, x if x != {} else None)
|
||||
|
||||
|
||||
common = set[str].intersection(
|
||||
*[{m for d in nid if (m := mid(d))} for nid in data.values()]
|
||||
)
|
||||
for c in common:
|
||||
print(c)
|
||||
@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
|
||||
done
|
||||
wait
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
@@ -34,13 +34,21 @@ reset=$'\e[0m'
|
||||
i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
{
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
|
||||
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
} <<'EOF'
|
||||
set -euo pipefail
|
||||
cd exo
|
||||
git fetch -q origin
|
||||
git checkout -q "$1"
|
||||
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
|
||||
EOF
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..."
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
until curl -sf "http://$host:52415/models"; do sleep 1; done
|
||||
done
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user