mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-28 12:17:24 -05:00
Compare commits
5 Commits
remove-cus
...
crash-avoi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17233f48ce | ||
|
|
635118ef24 | ||
|
|
dc0bb5e13b | ||
|
|
152a27ea5d | ||
|
|
db36bd5ac6 |
@@ -73,9 +73,11 @@ class GenerationResponse:
|
||||
finish_reason: Optional[str] = ...
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
): # -> None:
|
||||
...
|
||||
prompt_cache: Any,
|
||||
quantized_kv_start: int | None,
|
||||
kv_group_size: int | None,
|
||||
kv_bits: int | None,
|
||||
) -> None: ...
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
|
||||
@@ -16,7 +16,7 @@ class Cache(Protocol):
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
|
||||
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
@property
|
||||
def meta_state(self) -> Literal[""]: ...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v) -> None: ...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def is_trimmable(self) -> Literal[False]: ...
|
||||
@classmethod
|
||||
def from_state(cls, state, meta_state) -> Self: ...
|
||||
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
|
||||
...
|
||||
@property
|
||||
def state(self): # -> tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> Any:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
) -> tuple[array, array]: ...
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
|
||||
...
|
||||
def __getitem__(self, idx): ...
|
||||
@property
|
||||
def state(self): # -> list[Any | array] | list[array]:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@@ -253,10 +243,9 @@ class CacheList(_BaseCache):
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): ...
|
||||
def trim(self, n: int) -> int: ...
|
||||
@property
|
||||
def state(self): # -> list[Any]:
|
||||
...
|
||||
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
[workspace]
|
||||
resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/util",
|
||||
]
|
||||
members = ["rust/networking", "rust/exo_pyo3_bindings", "rust/util"]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.0.1"
|
||||
|
||||
@@ -2,6 +2,4 @@
|
||||
#
|
||||
# Lists the suite files to include. Each file defines benchmarks
|
||||
# with shared constraints, topology, and default args.
|
||||
include = [
|
||||
"single-m3-ultra.toml",
|
||||
]
|
||||
include = ["single-m3-ultra.toml"]
|
||||
|
||||
@@ -4,13 +4,13 @@ version = "0.1.0"
|
||||
description = "Benchmarking tool for exo distributed inference"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"loguru>=0.7.3",
|
||||
"transformers>=5.0.0",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"tiktoken>=0.12.0",
|
||||
"jinja2>=3.1.0",
|
||||
"protobuf>=5.29.0",
|
||||
"httpx>=0.27.0",
|
||||
"loguru>=0.7.3",
|
||||
"transformers>=5.0.0",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"tiktoken>=0.12.0",
|
||||
"jinja2>=3.1.0",
|
||||
"protobuf>=5.29.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
#
|
||||
# Shared constraints applied to ALL benchmarks in this file.
|
||||
constraints = [
|
||||
"All(MacOsBuild(=25D125))",
|
||||
"Hosts(=1)",
|
||||
"All(Chip(m3_ultra))",
|
||||
"All(GpuCores(=80))",
|
||||
"All(MacOsBuild(=25D125))",
|
||||
"Hosts(=1)",
|
||||
"All(Chip(m3_ultra))",
|
||||
"All(GpuCores(=80))",
|
||||
]
|
||||
|
||||
[topology]
|
||||
|
||||
@@ -3158,6 +3158,23 @@ class AppStore {
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete traces by task IDs
|
||||
*/
|
||||
async deleteTraces(
|
||||
taskIds: string[],
|
||||
): Promise<{ deleted: string[]; notFound: string[] }> {
|
||||
const response = await fetch("/v1/traces/delete", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ taskIds }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to delete traces: ${response.status}`);
|
||||
}
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
@@ -3301,3 +3318,5 @@ export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
|
||||
export const deleteTraces = (taskIds: string[]) =>
|
||||
appStore.deleteTraces(taskIds);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
deleteTraces,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
@@ -10,6 +11,51 @@
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
let selectedIds = $state<Set<string>>(new Set());
|
||||
let deleting = $state(false);
|
||||
|
||||
let allSelected = $derived(
|
||||
traces.length > 0 && selectedIds.size === traces.length,
|
||||
);
|
||||
|
||||
function toggleSelect(taskId: string) {
|
||||
const next = new Set(selectedIds);
|
||||
if (next.has(taskId)) {
|
||||
next.delete(taskId);
|
||||
} else {
|
||||
next.add(taskId);
|
||||
}
|
||||
selectedIds = next;
|
||||
}
|
||||
|
||||
function toggleSelectAll() {
|
||||
if (allSelected) {
|
||||
selectedIds = new Set();
|
||||
} else {
|
||||
selectedIds = new Set(traces.map((t) => t.taskId));
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete() {
|
||||
if (selectedIds.size === 0) return;
|
||||
const count = selectedIds.size;
|
||||
if (
|
||||
!confirm(
|
||||
`Delete ${count} trace${count === 1 ? "" : "s"}? This cannot be undone.`,
|
||||
)
|
||||
)
|
||||
return;
|
||||
deleting = true;
|
||||
try {
|
||||
await deleteTraces([...selectedIds]);
|
||||
selectedIds = new Set();
|
||||
await refresh();
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to delete traces";
|
||||
} finally {
|
||||
deleting = false;
|
||||
}
|
||||
}
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
@@ -109,6 +155,16 @@
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
{#if selectedIds.size > 0}
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-red-400 hover:text-red-300 transition-colors uppercase border border-red-500/40 px-2 py-1 rounded"
|
||||
onclick={handleDelete}
|
||||
disabled={deleting}
|
||||
>
|
||||
{deleting ? "Deleting..." : `Delete (${selectedIds.size})`}
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
@@ -143,14 +199,41 @@
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
<div class="flex items-center gap-2 px-1">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono uppercase transition-colors {allSelected
|
||||
? 'text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
onclick={toggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Deselect all" : "Select all"}
|
||||
</button>
|
||||
</div>
|
||||
{#each traces as trace}
|
||||
{@const isSelected = selectedIds.has(trace.taskId)}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="w-full text-left rounded border-l-2 border-r border-t border-b transition-all p-4 flex items-center justify-between gap-4 cursor-pointer {isSelected
|
||||
? 'bg-exo-yellow/10 border-l-exo-yellow border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30'
|
||||
: 'bg-exo-black/30 border-l-transparent border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30 hover:bg-white/[0.03]'}"
|
||||
onclick={() => toggleSelect(trace.taskId)}
|
||||
onkeydown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
toggleSelect(trace.taskId);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
class="text-sm font-mono transition-colors truncate block {isSelected
|
||||
? 'text-exo-yellow'
|
||||
: 'text-white hover:text-exo-yellow'}"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
@@ -160,7 +243,11 @@
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<div
|
||||
class="flex items-center gap-2 shrink-0"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
|
||||
@@ -108,6 +108,7 @@
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
taplo.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
100
pyproject.toml
100
pyproject.toml
@@ -5,31 +5,31 @@ description = "Exo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"pydantic>=2.11.7",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.7",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux==0.15.5",
|
||||
"python-multipart>=0.0.21",
|
||||
"msgspec>=0.19.0",
|
||||
"zstandard>=0.23.0",
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"pydantic>=2.11.7",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.7",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux==0.15.5",
|
||||
"python-multipart>=0.0.21",
|
||||
"msgspec>=0.19.0",
|
||||
"zstandard>=0.23.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -38,12 +38,12 @@ exo = "exo.main:main"
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"basedpyright>=1.29.0",
|
||||
"pyinstaller>=6.17.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-env",
|
||||
"ruff>=0.11.13",
|
||||
"basedpyright>=1.29.0",
|
||||
"pyinstaller>=6.17.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-env",
|
||||
"ruff>=0.11.13",
|
||||
]
|
||||
|
||||
# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
|
||||
@@ -57,10 +57,7 @@ dev = [
|
||||
###
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"rust/exo_pyo3_bindings",
|
||||
"bench",
|
||||
]
|
||||
members = ["rust/exo_pyo3_bindings", "bench"]
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
@@ -95,7 +92,15 @@ reportUnnecessaryTypeIgnoreComment = "error"
|
||||
pythonVersion = "3.13"
|
||||
pythonPlatform = "Darwin"
|
||||
|
||||
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "**/.github"]
|
||||
exclude = [
|
||||
"**/.venv",
|
||||
"**/venv",
|
||||
"**/__pycache__",
|
||||
"**/exo_scripts",
|
||||
"**/.direnv",
|
||||
"**/rust",
|
||||
"**/.github",
|
||||
]
|
||||
stubPath = ".mlx_typings"
|
||||
|
||||
[[tool.basedpyright.executionEnvironments]]
|
||||
@@ -109,17 +114,18 @@ root = "src"
|
||||
[tool.uv]
|
||||
required-version = ">=0.8.6"
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
environments = ["sys_platform == 'darwin'", "sys_platform == 'linux'"]
|
||||
|
||||
###
|
||||
# ruff configuration
|
||||
###
|
||||
|
||||
[tool.ruff]
|
||||
extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
|
||||
extend-exclude = [
|
||||
"shared/protobufs/**",
|
||||
"*mlx_typings/**",
|
||||
"rust/exo_pyo3_bindings/**",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
|
||||
@@ -127,13 +133,7 @@ extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = "."
|
||||
asyncio_mode = "auto"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselected by default)"
|
||||
]
|
||||
env = [
|
||||
"EXO_TESTS=1"
|
||||
]
|
||||
markers = ["slow: marks tests as slow (deselected by default)"]
|
||||
env = ["EXO_TESTS=1"]
|
||||
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
|
||||
filterwarnings = [
|
||||
"ignore:builtin type Swig:DeprecationWarning",
|
||||
]
|
||||
filterwarnings = ["ignore:builtin type Swig:DeprecationWarning"]
|
||||
|
||||
@@ -26,20 +26,24 @@ networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-async-runtimes = { version = "0.27.0", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
"testing",
|
||||
] }
|
||||
pyo3-log = "0.13.2"
|
||||
|
||||
# macro dependencies
|
||||
|
||||
@@ -8,18 +8,14 @@ version = "0.2.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" },
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"exo_pyo3_bindings",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
]
|
||||
dev = ["exo_pyo3_bindings", "pytest>=8.4.0", "pytest-asyncio>=1.0.0"]
|
||||
|
||||
[tool.maturin]
|
||||
#purelib = true
|
||||
|
||||
@@ -28,7 +28,10 @@ tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
tracing-subscriber = { version = "0.3.19", features = [
|
||||
"default",
|
||||
"env-filter",
|
||||
] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
|
||||
@@ -314,9 +314,13 @@ async def fetch_file_list_with_cache(
|
||||
_fetched_file_lists_this_session.add(cache_key)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
"Ran into exception when fetching file list from HF."
|
||||
)
|
||||
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
f"No cached file list for {model_id} - using local file list"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
|
||||
@@ -81,6 +81,8 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteTracesRequest,
|
||||
DeleteTracesResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -340,6 +342,7 @@ class API:
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.post("/v1/traces/delete")(self.delete_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
@@ -1707,8 +1710,12 @@ class API:
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
@staticmethod
|
||||
def _get_trace_path(task_id: str) -> Path:
|
||||
trace_path = EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
if not trace_path.resolve().is_relative_to(EXO_TRACING_CACHE_DIR.resolve()):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task ID: {task_id}")
|
||||
return trace_path
|
||||
|
||||
async def list_traces(self) -> TraceListResponse:
|
||||
traces: list[TraceListItem] = []
|
||||
@@ -1807,6 +1814,18 @@ class API:
|
||||
filename=f"trace_{task_id}.json",
|
||||
)
|
||||
|
||||
async def delete_traces(self, request: DeleteTracesRequest) -> DeleteTracesResponse:
|
||||
deleted: list[str] = []
|
||||
not_found: list[str] = []
|
||||
for task_id in request.task_ids:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
if trace_path.exists():
|
||||
trace_path.unlink()
|
||||
deleted.append(task_id)
|
||||
else:
|
||||
not_found.append(task_id)
|
||||
return DeleteTracesResponse(deleted=deleted, not_found=not_found)
|
||||
|
||||
async def get_onboarding(self) -> JSONResponse:
|
||||
return JSONResponse({"completed": ONBOARDING_COMPLETE_FILE.exists()})
|
||||
|
||||
|
||||
@@ -328,17 +328,22 @@ class Master:
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Nonexistent command {command.cancelled_command_id} cancelled"
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
case TaskFinished():
|
||||
if (
|
||||
task_id := self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Finished command {command.finished_command_id} finished"
|
||||
)
|
||||
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
|
||||
@@ -437,3 +437,12 @@ class TraceListItem(CamelCaseModel):
|
||||
|
||||
class TraceListResponse(CamelCaseModel):
|
||||
traces: list[TraceListItem]
|
||||
|
||||
|
||||
class DeleteTracesRequest(CamelCaseModel):
|
||||
task_ids: list[str]
|
||||
|
||||
|
||||
class DeleteTracesResponse(CamelCaseModel):
|
||||
deleted: list[str]
|
||||
not_found: list[str]
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx import core as mx
|
||||
from mlx import nn as nn
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
|
||||
|
||||
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: KVCacheType | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: list[KVCache] | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -49,6 +49,21 @@ TimeoutCallback = Callable[[], None]
|
||||
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
|
||||
|
||||
|
||||
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
|
||||
|
||||
|
||||
def flush_prefill_sends() -> None:
|
||||
for output, dst, group in _pending_prefill_sends:
|
||||
sent = mx.distributed.send(output, dst, group=group)
|
||||
mx.async_eval(sent)
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def clear_prefill_sends() -> None:
|
||||
# Discard pending sends (e.g. on cancellation).
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
@@ -150,6 +165,7 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
self.is_prefill: bool = False
|
||||
self.queue_sends: bool = False
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
@@ -163,9 +179,14 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if self.queue_sends:
|
||||
_pending_prefill_sends.append(
|
||||
(output, (self.r + 1) % self.s, self.group)
|
||||
)
|
||||
else:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
@@ -190,6 +211,12 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
|
||||
layer.is_prefill = is_prefill
|
||||
|
||||
|
||||
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
|
||||
for layer in model.layers: # type: ignore
|
||||
if isinstance(layer, PipelineLastLayer):
|
||||
layer.queue_sends = queue_sends
|
||||
|
||||
|
||||
def get_inner_model(model: nn.Module) -> nn.Module:
|
||||
inner = getattr(model, "model", None)
|
||||
if isinstance(inner, nn.Module):
|
||||
|
||||
@@ -13,8 +13,7 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.shared.types.mlx import KVCacheType, Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -254,9 +253,9 @@ def trim_cache(
|
||||
if snapshot is not None and snapshot.states[i] is not None:
|
||||
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
|
||||
else:
|
||||
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
c.state = [None] * len(c.state)
|
||||
else:
|
||||
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
|
||||
c.trim(num_tokens)
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import functools
|
||||
import math
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.generate import (
|
||||
maybe_quantize_kv_cache,
|
||||
stream_generate,
|
||||
)
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -19,13 +23,19 @@ from exo.shared.types.api import (
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.mlx import KVCacheType, Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
clear_prefill_sends,
|
||||
flush_prefill_sends,
|
||||
set_pipeline_prefill,
|
||||
set_pipeline_queue_sends,
|
||||
)
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
@@ -56,6 +66,130 @@ class PrefillCancelled(BaseException):
|
||||
"""Raised when prefill is cancelled via the progress callback."""
|
||||
|
||||
|
||||
def _has_pipeline_communication_layer(model: Model):
|
||||
for layer in model.layers:
|
||||
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def pipeline_parallel_prefill(
    model: Model,
    prompt: mx.array,
    prompt_cache: KVCacheType,
    prefill_step_size: int,
    kv_group_size: int | None,
    kv_bits: int | None,
    prompt_progress_callback: Callable[[int, int], None],
    distributed_prompt_progress_callback: Callable[[], None] | None,
    group: mx.distributed.Group,
) -> None:
    """Prefill the KV cache for pipeline parallel with overlapping stages.

    All prompt tokens except the last are split into chunks of at most
    ``prefill_step_size // min(4, group.size())`` tokens. Every rank feeds each
    chunk through ``model`` with ``prompt_cache``. Around those real iterations
    each rank runs padding iterations that only invoke
    ``distributed_prompt_progress_callback``, so that every rank calls it the
    same number of times.

    Total iterations per rank = N_real_chunks + world_size - 1:
      - rank r leading padding iterations (distributed callback only)
      - N_real_chunks real iterations (model call on the real cache)
      - (world_size-1-r) trailing padding iterations (distributed callback only)

    e.g.
    Timeline (2 ranks, 3 chunks of 10240 tokens @ effective step=4096):
        iter 0: R0 real[0:4096]      R1 padding
        iter 1: R0 real[4096:8192]   R1 real[0:4096]
        iter 2: R0 real[8192:10240]  R1 real[4096:8192]
        iter 3: R0 padding           R1 real[8192:10240]

    This function is designed to match mlx_lm's stream_generate exactly in terms of
    side effects (given the same prefill step size): after the chunks it runs the
    last prompt token through the model twice, so the cache ends up with the same
    extra entry that the caller trims after stream_generate.

    Args:
        model: Model (or pipeline shard) to run the prompt through.
        prompt: Token array for the prompt; indexed along its first axis.
        prompt_cache: Per-layer caches, updated in place.
        prefill_step_size: Requested chunk size before the division described above.
        kv_group_size: Forwarded to ``maybe_quantize_kv_cache``.
        kv_bits: Forwarded to ``maybe_quantize_kv_cache``.
        prompt_progress_callback: Called with ``(processed, total)`` before the
            first chunk, after every real chunk and once at the end.
        distributed_prompt_progress_callback: Optional callback invoked once per
            iteration (real or padding) on every rank.
        group: Distributed group providing ``rank()`` and ``size()``.

    Raises:
        Whatever the callbacks or the model raise; queued prefill sends are
        cleared via ``clear_prefill_sends()`` before the exception propagates
        out of the chunk loop.
    """
    # Clamp to at least 1: a step of 0 (requested size smaller than the divisor)
    # would never shrink `remaining` and the chunking loop below would not end.
    prefill_step_size = max(1, prefill_step_size // min(4, group.size()))

    quantize_cache_fn: Callable[..., None] = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=0,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    _prompt_cache: KVCacheType = prompt_cache
    rank = group.rank()
    world_size = group.size()

    # Build list of real prompt chunk sizes. The final token is held back for the
    # post-loop step. `> 0` (rather than truthiness) keeps an empty prompt from
    # producing a bogus chunk of size -1.
    total = len(prompt)
    real_chunk_sizes: list[int] = []
    remaining = total - 1
    while remaining > 0:
        n = min(prefill_step_size, remaining)
        real_chunk_sizes.append(n)
        remaining -= n
    n_real = len(real_chunk_sizes)

    # Each rank does: [rank leading padding] [N real chunks] [world_size-1-rank trailing padding]
    n_leading = rank
    n_trailing = world_size - 1 - rank
    n_total = n_leading + n_real + n_trailing

    t_start = time.perf_counter()
    processed = 0
    logger.info(
        f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
    )
    clear_prefill_sends()

    # Initial callback matching generate_step
    prompt_progress_callback(0, total)

    # NOTE(review): block nesting below was reconstructed from a de-indented
    # listing; confirm against the repository before merging.
    try:
        with mx.stream(generation_stream):
            for _ in range(n_leading):
                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

            for chunk_size in real_chunk_sizes:
                model(
                    prompt[processed : processed + chunk_size][None],
                    cache=_prompt_cache,
                )
                quantize_cache_fn(_prompt_cache)
                processed += chunk_size

                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

                flush_prefill_sends()

                prompt_progress_callback(processed, total)

            for _ in range(n_trailing):
                if distributed_prompt_progress_callback is not None:
                    distributed_prompt_progress_callback()

    finally:
        clear_prefill_sends()

    # Post-loop: process remaining 1 token + add +1 entry to match stream_generate.
    for _ in range(2):
        with mx.stream(generation_stream):
            model(prompt[-1:][None], cache=_prompt_cache)
            quantize_cache_fn(_prompt_cache)
        flush_prefill_sends()

    assert _prompt_cache is not None
    # Force evaluation of every cache entry's state before reporting completion.
    mx.eval([c.state for c in _prompt_cache])  # type: ignore

    # Final callback matching generate_step
    prompt_progress_callback(total, total)

    logger.info(
        f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
        f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
    )
|
||||
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -64,6 +198,7 @@ def prefill(
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None,
|
||||
distributed_prompt_progress_callback: Callable[[], None] | None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
@@ -95,31 +230,57 @@ def prefill(
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
def combined_progress_callback(processed: int, total: int) -> None:
|
||||
if distributed_prompt_progress_callback is not None:
|
||||
distributed_prompt_progress_callback()
|
||||
progress_callback(processed, total)
|
||||
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
is_pipeline = _has_pipeline_communication_layer(model)
|
||||
|
||||
prefill_step_size = 4096
|
||||
|
||||
try:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=4096,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
if is_pipeline and num_tokens >= prefill_step_size:
|
||||
set_pipeline_queue_sends(model, queue_sends=True)
|
||||
assert group is not None, "Pipeline prefill requires a distributed group"
|
||||
pipeline_parallel_prefill(
|
||||
model=model,
|
||||
prompt=prompt_tokens,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
|
||||
group=group,
|
||||
)
|
||||
else:
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=combined_progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
except PrefillCancelled:
|
||||
set_pipeline_queue_sends(model, queue_sends=False)
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
raise
|
||||
|
||||
set_pipeline_queue_sends(model, queue_sends=False)
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
@@ -132,7 +293,7 @@ def prefill(
|
||||
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
c.trim(2)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
@@ -275,6 +436,7 @@ def mlx_generate(
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
distributed_prompt_progress_callback: Callable[[], None] | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -336,6 +498,7 @@ def mlx_generate(
|
||||
caches,
|
||||
group,
|
||||
on_prefill_progress,
|
||||
distributed_prompt_progress_callback,
|
||||
)
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from pydantic import RootModel
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
@@ -52,7 +53,6 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
LayerLoadedCallback,
|
||||
TimeoutCallback,
|
||||
|
||||
@@ -31,6 +31,7 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
@@ -63,7 +64,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
PrefillCancelled,
|
||||
@@ -274,8 +274,6 @@ def main(
|
||||
def on_prefill_progress(
|
||||
processed: int,
|
||||
total: int,
|
||||
_task_id: TaskId = task.task_id,
|
||||
_group: mx.distributed.Group | None = group,
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
@@ -288,6 +286,11 @@ def main(
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def distributed_prompt_progress_callback(
|
||||
_task_id: TaskId = task.task_id,
|
||||
_group: mx.distributed.Group | None = group,
|
||||
) -> None:
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (_task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
@@ -309,6 +312,7 @@ def main(
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
|
||||
group=group,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
cache_length,
|
||||
@@ -143,7 +143,14 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
|
||||
# Cache should now hold the prompt tokens minus one
|
||||
@@ -164,7 +171,14 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -200,7 +214,14 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
short_tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -245,7 +266,14 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -285,7 +313,14 @@ class TestKVPrefixCacheWithModel:
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
_, _, snapshots = prefill(
|
||||
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(None)
|
||||
@@ -513,7 +548,16 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
@@ -538,7 +582,16 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
|
||||
prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
make_sampler(0.0),
|
||||
tokens,
|
||||
cache,
|
||||
group=None,
|
||||
on_prefill_progress=None,
|
||||
distributed_prompt_progress_callback=None,
|
||||
)
|
||||
kv_prefix_cache.add_kv_cache(tokens, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
|
||||
@@ -0,0 +1,512 @@
|
||||
# type: ignore
|
||||
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
|
||||
|
||||
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
|
||||
then verifies that the prompt_progress_callback sequences are identical
|
||||
and that generated text matches.
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
# Identifier of the model under test and the directory it is expected in
# (the tests are skipped when MODEL_PATH does not exist).
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
# Layer count used for the model card and the per-rank layer ranges.
TOTAL_LAYERS = 24
# Generation settings placed on every task built by _build_prompt.
MAX_TOKENS = 10
SEED = 42
TEMPERATURE = 0.0
|
||||
|
||||
|
||||
def _model_card() -> ModelCard:
    """Build the ModelCard describing the model used by these tests."""
    model_id = ModelId(MODEL_ID)
    storage_size = Memory.from_gb(12)
    supported_tasks = [ModelTask.TextGeneration]
    return ModelCard(
        model_id=model_id,
        storage_size=storage_size,
        n_layers=TOTAL_LAYERS,
        hidden_size=2880,
        supports_tensor=False,
        tasks=supported_tasks,
    )
|
||||
|
||||
|
||||
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
    """Create a chat prompt whose user message is ``prompt_tokens`` tokens long.

    A filler sentence is repeated often enough to exceed the requested length,
    tokenized, cut to ``prompt_tokens`` tokens and decoded back into the user
    message. Returns ``(chat_prompt, task)`` where ``chat_prompt`` is the
    chat-templated prompt and ``task`` is the task parameters it was built from.
    """
    from exo.worker.engines.mlx.utils_mlx import apply_chat_template

    sentence = "The quick brown fox jumps over the lazy dog. "
    tokens_per_sentence = len(tokenizer.encode(sentence))
    # Two spare repetitions so truncation always has enough tokens to cut from.
    filler = sentence * ((prompt_tokens // tokens_per_sentence) + 2)
    truncated = tokenizer.encode(filler)[:prompt_tokens]
    user_message = tokenizer.decode(truncated)

    task = TextGenerationTaskParams(
        model=MODEL_ID,
        input=[InputMessage(role="user", content=user_message)],
        max_output_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        seed=SEED,
    )

    return apply_chat_template(tokenizer, task), task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-device process: uses stream_generate path (no pipeline layers)
|
||||
# ---------------------------------------------------------------------------
|
||||
def _run_single_device(
    prompt_tokens: int,
    result_queue: Any,
) -> None:
    """Child-process entry point: generate on one unsharded model and report back.

    Puts ``(True, payload)`` on ``result_queue`` where ``payload`` holds the
    recorded prefill callbacks, the generated text and the prefill token count.
    If anything raises ``Exception``, puts ``(False, message)`` instead, with the
    error text followed by the formatted traceback.
    """
    try:
        import mlx.core as mx
        from mlx_lm.utils import load_model

        from exo.shared.types.worker.shards import PipelineShardMetadata
        from exo.worker.engines.mlx.cache import encode_prompt
        from exo.worker.engines.mlx.generator.generate import mlx_generate
        from exo.worker.engines.mlx.utils_mlx import (
            build_model_path,
            get_tokenizer,
        )

        weights_path = build_model_path(ModelId(MODEL_ID))
        model, _ = load_model(weights_path, lazy=True, strict=False)
        mx.eval(model)

        # get_tokenizer takes shard metadata, so describe a single shard spanning
        # every layer. No pipeline sharding is applied to the model itself.
        whole_model_meta = PipelineShardMetadata(
            model_card=_model_card(),
            device_rank=0,
            world_size=1,
            start_layer=0,
            end_layer=TOTAL_LAYERS,
            n_layers=TOTAL_LAYERS,
        )
        tokenizer = get_tokenizer(weights_path, whole_model_meta)

        prompt, task = _build_prompt(tokenizer, prompt_tokens)

        progress_log: list[tuple[int, int]] = []

        def record_progress(processed: int, total: int) -> None:
            progress_log.append((processed, total))

        pieces: list[str] = []
        for response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=None,
            group=None,
            on_prefill_progress=record_progress,
        ):
            pieces.append(response.text)
            if response.finish_reason is not None:
                break
        generated_text = "".join(pieces)

        # Report the prefill length as one less than the encoded prompt length.
        prefill_token_count = len(encode_prompt(tokenizer, prompt)) - 1

        payload = {
            "callbacks": progress_log,
            "text": generated_text,
            "prefill_token_count": prefill_token_count,
        }
        result_queue.put((True, payload))

    except Exception as e:
        result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline device process: uses the pipeline_parallel_prefill path for long prompts
|
||||
# ---------------------------------------------------------------------------
|
||||
def _run_pipeline_device(
    rank: int,
    world_size: int,
    hostfile_path: str,
    layer_splits: list[tuple[int, int]],
    prompt_tokens: int,
    result_queue: Any,
) -> None:
    """Child-process entry point for one pipeline rank.

    Exports ``MLX_HOSTFILE``/``MLX_RANK``, initialises the ring backend, loads
    this rank's layer range, runs ``mlx_generate`` and puts
    ``(rank, True, payload)`` on ``result_queue``. If anything raises
    ``Exception``, puts ``(rank, False, message)`` with the error text and the
    formatted traceback.
    """
    os.environ["MLX_HOSTFILE"] = hostfile_path
    os.environ["MLX_RANK"] = str(rank)

    try:
        import mlx.core as mx

        from exo.shared.types.worker.shards import PipelineShardMetadata
        from exo.worker.engines.mlx.cache import encode_prompt
        from exo.worker.engines.mlx.generator.generate import mlx_generate
        from exo.worker.engines.mlx.utils_mlx import shard_and_load

        group = mx.distributed.init(backend="ring", strict=True)

        first_layer, last_layer = layer_splits[rank]
        shard_meta = PipelineShardMetadata(
            model_card=_model_card(),
            device_rank=rank,
            world_size=world_size,
            start_layer=first_layer,
            end_layer=last_layer,
            n_layers=TOTAL_LAYERS,
        )

        loaded_model, tokenizer = shard_and_load(
            shard_meta, group, on_timeout=None, on_layer_loaded=None
        )
        model = cast(Any, loaded_model)

        prompt, task = _build_prompt(tokenizer, prompt_tokens)

        progress_log: list[tuple[int, int]] = []

        def record_progress(processed: int, total: int) -> None:
            progress_log.append((processed, total))

        def sync_ranks(_group: Any = group) -> None:
            # Calls mx_any with a constant False on the group each time the
            # generator invokes this callback.
            from exo.worker.engines.mlx.utils_mlx import mx_any

            mx_any(False, _group)

        pieces: list[str] = []
        for response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=None,
            group=group,
            on_prefill_progress=record_progress,
            distributed_prompt_progress_callback=sync_ranks,
        ):
            pieces.append(response.text)
            if response.finish_reason is not None:
                break
        generated_text = "".join(pieces)

        prefill_token_count = len(encode_prompt(tokenizer, prompt)) - 1

        payload = {
            "callbacks": progress_log,
            "text": generated_text,
            "prefill_token_count": prefill_token_count,
        }
        result_queue.put((rank, True, payload))

    except Exception as e:
        result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _create_hostfile(world_size: int, base_port: int, host: str = "127.0.0.1") -> str:
    """Write a JSON hostfile listing one ``host:port`` entry per rank.

    Ports are consecutive, starting at ``base_port``. The file is a named
    temporary file that is NOT deleted automatically; the caller is
    responsible for removing it.

    Args:
        world_size: Number of ranks, i.e. number of entries to write.
        base_port: Port assigned to rank 0; rank ``i`` gets ``base_port + i``.
        host: Address used for every entry. Defaults to the loopback address.

    Returns:
        Path of the written hostfile.
    """
    hosts = [f"{host}:{base_port + offset}" for offset in range(world_size)]
    # Explicit encoding so the file content does not depend on the locale.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as f:
        json.dump(hosts, f)
    return f.name
|
||||
|
||||
|
||||
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
    """Run ``_run_single_device`` in a spawned process and return its payload.

    Fails the test if the process is still alive after ``timeout`` seconds,
    if it put nothing on the queue, or if it reported failure.
    """
    spawn_ctx = mp.get_context("spawn")
    result_queue: Any = spawn_ctx.Queue()

    worker = spawn_ctx.Process(
        target=_run_single_device, args=(prompt_tokens, result_queue)
    )
    worker.start()
    worker.join(timeout=timeout)

    if worker.is_alive():
        # Still running after the deadline: stop it and fail the test.
        worker.terminate()
        worker.join(timeout=5)
        pytest.fail("Single-device process timed out")

    assert not result_queue.empty(), "Single-device process produced no result"
    outcome = result_queue.get()
    success, data = outcome
    assert success, f"Single-device process failed:\n{data}"
    return data
|
||||
|
||||
|
||||
def _run_pipeline_test(
    layer_splits: list[tuple[int, int]],
    prompt_tokens: int,
    base_port: int,
    timeout: int = 120,
) -> dict[int, dict[str, Any]]:
    """Spawn one ``_run_pipeline_device`` process per layer split and gather results.

    Returns a mapping from rank to that rank's payload. Asserts that no process
    outlived its join timeout, that every rank reported success and that every
    rank delivered a result. The temporary hostfile is removed on every exit path.
    """
    world_size = len(layer_splits)
    hostfile_path = _create_hostfile(world_size, base_port)
    spawn_ctx = mp.get_context("spawn")
    result_queue: Any = spawn_ctx.Queue()

    try:
        workers: list[Any] = []
        for rank in range(world_size):
            worker_args = (
                rank,
                world_size,
                hostfile_path,
                layer_splits,
                prompt_tokens,
                result_queue,
            )
            worker = spawn_ctx.Process(target=_run_pipeline_device, args=worker_args)
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join(timeout=timeout)

        # Record the timeout verdict before cleaning up stragglers, so the
        # assertion reflects the state right after the joins.
        timed_out = any(worker.is_alive() for worker in workers)
        for worker in workers:
            if not worker.is_alive():
                continue
            worker.terminate()
            worker.join(timeout=5)

        assert not timed_out, "Pipeline processes timed out"

        results: dict[int, dict[str, Any]] = {}
        while not result_queue.empty():
            rank, success, data = result_queue.get()
            assert success, f"Pipeline rank {rank} failed:\n{data}"
            results[rank] = data

        assert len(results) == world_size, (
            f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
        )
        return results

    finally:
        os.unlink(hostfile_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-wide marks: every test in this file is tagged `slow` and is skipped
# when the model directory MODEL_PATH does not exist.
pytestmark = [
    pytest.mark.slow,
    pytest.mark.skipif(
        not MODEL_PATH.exists(),
        reason=f"GPT-OSS model not found at {MODEL_PATH}",
    ),
]

# (start_layer, end_layer) pairs, one per rank, together covering layers 0..24.
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
|
||||
|
||||
|
||||
class TestPipelineNoDeadlock:
    """Pipeline generation must finish for every tested rank count and prompt length."""

    @pytest.mark.parametrize(
        "layer_splits,prompt_tokens",
        [
            pytest.param(LAYER_SPLITS_2WAY, 128, id="2rank_128tok"),
            pytest.param(LAYER_SPLITS_2WAY, 4096, id="2rank_4096tok"),
            pytest.param(LAYER_SPLITS_2WAY, 8192, id="2rank_8192tok"),
            pytest.param(LAYER_SPLITS_2WAY, 16384, id="2rank_16384tok"),
            pytest.param(LAYER_SPLITS_4WAY, 128, id="4rank_128tok"),
            pytest.param(LAYER_SPLITS_4WAY, 4096, id="4rank_4096tok"),
            pytest.param(LAYER_SPLITS_4WAY, 8192, id="4rank_8192tok"),
            pytest.param(LAYER_SPLITS_4WAY, 16384, id="4rank_16384tok"),
        ],
    )
    def test_no_deadlock(
        self,
        layer_splits: list[tuple[int, int]],
        prompt_tokens: int,
    ) -> None:
        """Every rank must finish within the timeout and produce some text."""
        per_rank = _run_pipeline_test(
            layer_splits=layer_splits,
            prompt_tokens=prompt_tokens,
            base_port=29650,
            timeout=60,
        )
        # Reaching this point means the helper did not time out; now make sure
        # each rank actually generated something.
        for rank, pipe_data in sorted(per_rank.items()):
            assert pipe_data["text"], f"Rank {rank} produced no output text"
|
||||
|
||||
|
||||
class TestPipelinePrefillCallbacks:
    """Check prefill callbacks across pipeline ranks and output against a single device."""

    @pytest.mark.parametrize(
        "prompt_tokens",
        [
            pytest.param(50, id="short_50"),
            pytest.param(500, id="medium_500"),
            pytest.param(5000, id="long_5000"),
        ],
    )
    def test_callbacks_match(self, prompt_tokens: int) -> None:
        """Every rank of a 4-rank pipeline must record the same callback sequence as rank 0."""
        pipeline_results = _run_pipeline_test(
            layer_splits=LAYER_SPLITS_4WAY,
            prompt_tokens=prompt_tokens,
            base_port=29700,
            timeout=180,
        )

        # Rank 0 is the reference that all ranks (itself included) are compared to.
        reference = pipeline_results[0]
        rank0_callbacks = reference["callbacks"]
        prefill_count = reference["prefill_token_count"]

        for rank, pipe_data in sorted(pipeline_results.items()):
            pipe_callbacks = pipe_data["callbacks"]

            assert pipe_data["prefill_token_count"] == prefill_count, (
                f"Rank {rank} prefill token count mismatch: "
                f"{pipe_data['prefill_token_count']} vs {prefill_count}"
            )

            assert pipe_callbacks == rank0_callbacks, (
                f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
                f"(prefill M={prefill_count}):\n"
                f" pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\n"
                f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
            )

        # Shape of the sequence: begins at (0, M), finishes at (M, M) and the
        # processed count never goes backwards in between.
        first_callback = rank0_callbacks[0]
        assert first_callback == (0, prefill_count), (
            f"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}"
        )
        final_callback = rank0_callbacks[-1]
        assert final_callback == (prefill_count, prefill_count), (
            f"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}"
        )
        consecutive = zip(rank0_callbacks, rank0_callbacks[1:])
        for i, (earlier, later) in enumerate(consecutive, start=1):
            assert later[0] >= earlier[0], (
                f"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}"
            )

    @pytest.mark.parametrize(
        "prompt_tokens",
        [
            pytest.param(50, id="short_50"),
            pytest.param(500, id="medium_500"),
        ],
    )
    def test_output_matches(self, prompt_tokens: int) -> None:
        """Text from the last pipeline rank must equal the single-device text."""
        single = _run_single_device_test(prompt_tokens, timeout=180)

        pipeline_results = _run_pipeline_test(
            layer_splits=LAYER_SPLITS_4WAY,
            prompt_tokens=prompt_tokens,
            base_port=29800,
            timeout=180,
        )

        single_text = single["text"]

        # Compare against the highest rank's text.
        last_rank = max(pipeline_results.keys())
        pipe_text = pipeline_results[last_rank]["text"]

        if single_text == pipe_text:
            return

        # Any difference fails the test; report the index of the first differing
        # character (or the shorter length when one text is a prefix of the other).
        min_len = min(len(single_text), len(pipe_text))
        diverge_idx = min_len
        for i in range(min_len):
            if single_text[i] != pipe_text[i]:
                diverge_idx = i
                break
        pytest.fail(
            f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
            f" single-device: {single_text!r}\n"
            f" pipeline R{last_rank}: {pipe_text!r}"
        )
|
||||
|
||||
|
||||
class TestPipelineCallbacksStructure:
    """Check the shape of each rank's callback sequence, ignoring generated text."""

    def test_callback_structure_matches_generate_step(self) -> None:
        """Each rank's callbacks must read (0, M), strictly-below-M steps, then (M, M)."""
        prompt_tokens = 200
        pipeline_results = _run_pipeline_test(
            layer_splits=LAYER_SPLITS_4WAY,
            prompt_tokens=prompt_tokens,
            base_port=29900,
            timeout=180,
        )

        for rank, pipe_data in sorted(pipeline_results.items()):
            callbacks = pipe_data["callbacks"]
            m = pipe_data["prefill_token_count"]
            assert m > 0, f"Rank {rank}: prefill token count is 0"

            opening = callbacks[0]
            assert opening == (0, m), (
                f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
            )

            closing = callbacks[-1]
            assert closing == (m, m), (
                f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
            )

            # With at least one intermediate callback, the one just before the
            # end must not have reached M yet.
            if len(callbacks) > 2:
                second_to_last = callbacks[-2]
                assert second_to_last[0] < m, (
                    f"Rank {rank}: second-to-last callback should report < {m}, "
                    f"got {second_to_last}"
                )

            # Every callback reports the same total, M.
            for i, callback in enumerate(callbacks):
                total = callback[1]
                assert total == m, (
                    f"Rank {rank}: callback {i} has total={total}, expected {m}"
                )

            # The processed count never decreases from one callback to the next.
            processed_vals = [callback[0] for callback in callbacks]
            processed_pairs = zip(processed_vals, processed_vals[1:])
            for i, (earlier, later) in enumerate(processed_pairs, start=1):
                assert later >= earlier, (
                    f"Rank {rank}: callbacks not non-decreasing at index {i}: "
                    f"{processed_vals}"
                )

            # Two identical callbacks in a row would indicate that padding
            # iterations are reporting progress.
            callback_pairs = zip(callbacks, callbacks[1:])
            for i, (earlier, later) in enumerate(callback_pairs, start=1):
                assert later != earlier, (
                    f"Rank {rank}: duplicate consecutive callback at index {i}: "
                    f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
                )
|
||||
@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
|
||||
Reference in New Issue
Block a user