Compare commits

...

9 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
1317354368 address comments 2026-02-23 09:53:06 +00:00
Ryuichi Leo Takashige
ca15f084c0 Format mlx typings oops 2026-02-22 01:53:34 +00:00
Ryuichi Leo Takashige
1097d3e3dd Fix maybe_qunatize_kv_cache typing 2026-02-22 01:32:16 +00:00
rltakashige
744cc0225a Merge branch 'main' into leo/add-custom-prefill-for-pipeline 2026-02-22 01:27:17 +00:00
Ryuichi Leo Takashige
a46330dd09 remove unnecessary stuff 2026-02-22 01:22:55 +00:00
Ryuichi Leo Takashige
951fa7b270 remove debug logs 2026-02-22 01:01:08 +00:00
Ryuichi Leo Takashige
20551d9333 add debug logs 2026-02-22 00:32:15 +00:00
Ryuichi Leo Takashige
af0ac9fbf8 separate callbacks and add test 2026-02-22 00:17:58 +00:00
Ryuichi Leo Takashige
3954ebf435 add custom prefill for pipeline parallel models 2026-02-21 22:01:39 +00:00
13 changed files with 749 additions and 85 deletions

View File

@@ -73,9 +73,11 @@ class GenerationResponse:
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
prompt_cache: Any,
quantized_kv_start: int | None,
kv_group_size: int | None,
kv_bits: int | None,
) -> None: ...
def generate_step(
prompt: mx.array,
model: nn.Module,

View File

@@ -16,7 +16,7 @@ class Cache(Protocol):
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@property
def meta_state(self) -> Literal[""]: ...
@meta_state.setter
def meta_state(self, v) -> None: ...
def trim(self, n: int) -> int: ...
def is_trimmable(self) -> Literal[False]: ...
@classmethod
def from_state(cls, state, meta_state) -> Self: ...
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
@property
def state(
self,
) -> tuple[array, array]: ...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@@ -253,10 +243,9 @@ class CacheList(_BaseCache):
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
def trim(self, n: int) -> int: ...
@property
def state(self): # -> list[Any]:
...
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
@state.setter
def state(self, v): # -> None:
...

View File

@@ -2,6 +2,8 @@
from collections.abc import Sequence
from mlx import core as mx
from mlx import nn as nn
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: KVCacheType | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -1,17 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -49,6 +49,21 @@ if TYPE_CHECKING:
TimeoutCallback = Callable[[], None]
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
def flush_prefill_sends() -> None:
for output, dst, group in _pending_prefill_sends:
sent = mx.distributed.send(output, dst, group=group)
mx.async_eval(sent)
_pending_prefill_sends.clear()
def clear_prefill_sends() -> None:
# Discard pending sends (e.g. on cancellation).
_pending_prefill_sends.clear()
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
@@ -159,18 +174,21 @@ class PipelineLastLayer(CustomMlxLayer):
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
if self.is_prefill:
mx.eval(output)
if cache is not None:
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
mx.eval(_cache.keys) # type: ignore
_pending_prefill_sends.append(
(output, (self.r + 1) % self.s, self.group)
)
else:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[

View File

@@ -13,8 +13,7 @@ from mlx_lm.models.cache import (
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.shared.types.mlx import KVCacheType, Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
@@ -254,9 +253,9 @@ def trim_cache(
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
c.state = [None] * len(c.state)
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
c.trim(num_tokens)
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:

View File

@@ -1,9 +1,13 @@
import functools
import time
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.generate import (
maybe_quantize_kv_cache,
stream_generate,
)
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -18,13 +22,18 @@ from exo.shared.types.api import (
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.mlx import KVCacheType, Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
clear_prefill_sends,
flush_prefill_sends,
set_pipeline_prefill,
)
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
@@ -55,6 +64,127 @@ class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def _has_pipeline_communication_layer(model: Model):
for layer in model.layers:
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
return True
return False
def pipeline_parallel_prefill(
model: Model,
prompt: mx.array,
prompt_cache: KVCacheType,
prefill_step_size: int,
kv_group_size: int | None,
kv_bits: int | None,
prompt_progress_callback: Callable[[int, int], None],
distributed_prompt_progress_callback: Callable[[], None] | None,
group: mx.distributed.Group,
) -> None:
"""Prefill the KV cache for pipeline parallel with overlapping stages.
Each rank processes the full prompt through its real cache, offset by leading
and trailing dummy iterations.
Total iterations per rank = N_real_chunks + world_size - 1:
- rank r leading dummies (skip_pipeline_io, throwaway cache)
- N_real_chunks real (pipeline IO active, real cache)
- (world_size-1-r) trailing dummies (skip_pipeline_io, throwaway cache)
e.g.
Timeline (2 ranks, 3 chunks of 10240 tokens @ step=4096):
iter 0: R0 real[0:4096] R1 dummy
iter 1: R0 real[4096:8192] R1 real[0:4096]
iter 2: R0 real[8192:10240] R1 real[4096:8192]
iter 3: R0 dummy R1 real[8192:10240]
This function is designed to match mlx_lm's stream_generate exactly in terms of side effects.
"""
quantize_cache_fn: Callable[..., None] = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=0,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
_prompt_cache: KVCacheType = prompt_cache
rank = group.rank()
world_size = group.size()
# Build list of real prompt chunk sizes
total = len(prompt)
real_chunk_sizes: list[int] = []
remaining = total - 1
while remaining:
n = min(prefill_step_size, remaining)
real_chunk_sizes.append(n)
remaining -= n
n_real = len(real_chunk_sizes)
# Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]
n_leading = rank
n_trailing = world_size - 1 - rank
n_total = n_leading + n_real + n_trailing
t_start = time.perf_counter()
processed = 0
logger.info(
f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
)
clear_prefill_sends()
# Initial callback matching generate_step
prompt_progress_callback(0, total)
try:
with mx.stream(generation_stream):
for _ in range(n_leading):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
for i in range(n_real):
chunk_size = real_chunk_sizes[i]
model(
prompt[processed : processed + chunk_size][None],
cache=_prompt_cache,
)
quantize_cache_fn(_prompt_cache)
processed += chunk_size
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
flush_prefill_sends()
prompt_progress_callback(processed, total)
for _ in range(n_leading):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
finally:
clear_prefill_sends()
# Post-loop: process remaining 1 token + add +1 entry to match stream_generate.
for _ in range(2):
with mx.stream(generation_stream):
model(prompt[-1:][None], cache=_prompt_cache)
quantize_cache_fn(_prompt_cache)
flush_prefill_sends()
assert _prompt_cache is not None
mx.eval([c.state for c in _prompt_cache]) # type: ignore
# Final callback matching generate_step
prompt_progress_callback(total, total)
logger.info(
f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
@@ -63,6 +193,7 @@ def prefill(
cache: KVCacheType,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
distributed_prompt_progress_callback: Callable[[], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -94,27 +225,48 @@ def prefill(
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
def combined_progress_callback(processed: int, total: int) -> None:
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
progress_callback(processed, total)
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
is_pipeline = _has_pipeline_communication_layer(model)
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
if is_pipeline:
assert group is not None, "Pipeline prefill requires a distributed group"
pipeline_parallel_prefill(
model=model,
prompt=prompt_tokens,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)
else:
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=combined_progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_prefill(model, is_prefill=False)
raise
@@ -131,7 +283,7 @@ def prefill(
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
c.trim(2)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -271,6 +423,7 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -332,6 +485,7 @@ def mlx_generate(
caches,
group,
on_prefill_progress,
distributed_prompt_progress_callback,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

View File

@@ -41,6 +41,7 @@ from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -53,7 +54,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,

View File

@@ -31,6 +31,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -63,7 +64,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
@@ -255,8 +255,6 @@ def main(
def on_prefill_progress(
processed: int,
total: int,
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
@@ -269,6 +267,11 @@ def main(
),
)
)
def distributed_prompt_progress_callback(
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
@@ -290,6 +293,7 @@ def main(
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)

View File

@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load

View File

@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,

View File

@@ -0,0 +1,500 @@
# type: ignore
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
then verifies that the prompt_progress_callback sequences are identical
and that generated text matches.
"""
import json
import multiprocessing as mp
import os
import tempfile
import traceback
from typing import Any, cast
import pytest
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
TOTAL_LAYERS = 24
MAX_TOKENS = 10
SEED = 42
TEMPERATURE = 0.0
def _model_card() -> ModelCard:
return ModelCard(
model_id=ModelId(MODEL_ID),
storage_size=Memory.from_gb(12),
n_layers=TOTAL_LAYERS,
hidden_size=2880,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
)
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
"""Build a prompt with the given number of user-content tokens, return (chat_prompt, task)."""
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
base_text = "The quick brown fox jumps over the lazy dog. "
base_toks = tokenizer.encode(base_text)
repeats = (prompt_tokens // len(base_toks)) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
model=MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
seed=SEED,
)
prompt = apply_chat_template(tokenizer, task)
return prompt, task
# ---------------------------------------------------------------------------
# Single-device process: uses stream_generate path (no pipeline layers)
# ---------------------------------------------------------------------------
def _run_single_device(
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load full model without pipeline sharding, run mlx_generate, record callbacks."""
try:
import mlx.core as mx
from mlx_lm.utils import load_model
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
build_model_path,
get_tokenizer,
)
model_path = build_model_path(ModelId(MODEL_ID))
model, _ = load_model(model_path, lazy=True, strict=False)
mx.eval(model)
# Use PipelineShardMetadata just for get_tokenizer (needs model_card), but
# do NOT apply pipeline sharding — the model keeps all layers unwrapped.
dummy_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=TOTAL_LAYERS,
n_layers=TOTAL_LAYERS,
)
tokenizer = get_tokenizer(model_path, dummy_meta)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=None,
on_prefill_progress=on_progress,
):
generated_text += response.text
if response.finish_reason is not None:
break
# Also record the token count that prefill() received (prompt_tokens[:-1])
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Pipeline device process: uses _pipeline_prefill_cache path
# ---------------------------------------------------------------------------
def _run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load model with pipeline sharding, run mlx_generate, record callbacks."""
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
try:
import mlx.core as mx
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
group = mx.distributed.init(backend="ring", strict=True)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=TOTAL_LAYERS,
)
model, tokenizer = shard_and_load(shard_meta, group)
model = cast(Any, model)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
def distributed_prompt_progress_callback(_group: Any = group) -> None:
from exo.worker.engines.mlx.utils_mlx import mx_any
mx_any(False, _group)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
on_prefill_progress=on_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
):
generated_text += response.text
if response.finish_reason is not None:
break
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
rank,
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _create_hostfile(world_size: int, base_port: int) -> str:
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
return f.name
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
"""Run single-device (stream_generate) prefill and return results."""
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.terminate()
p.join(timeout=5)
pytest.fail("Single-device process timed out")
assert not result_queue.empty(), "Single-device process produced no result"
success, data = result_queue.get()
assert success, f"Single-device process failed:\n{data}"
return data
def _run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
base_port: int,
timeout: int = 120,
) -> dict[int, dict[str, Any]]:
"""Run pipeline prefill across ranks and return per-rank results."""
world_size = len(layer_splits)
hostfile_path = _create_hostfile(world_size, base_port)
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
try:
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=_run_pipeline_device,
args=(
rank,
world_size,
hostfile_path,
layer_splits,
prompt_tokens,
result_queue,
),
)
p.start()
processes.append(p)
for p in processes:
p.join(timeout=timeout)
timed_out = any(p.is_alive() for p in processes)
for p in processes:
if p.is_alive():
p.terminate()
p.join(timeout=5)
assert not timed_out, "Pipeline processes timed out"
results: dict[int, dict[str, Any]] = {}
while not result_queue.empty():
rank, success, data = result_queue.get()
assert success, f"Pipeline rank {rank} failed:\n{data}"
results[rank] = data
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
)
return results
finally:
os.unlink(hostfile_path)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not MODEL_PATH.exists(),
reason=f"GPT-OSS model not found at {MODEL_PATH}",
),
]
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
class TestPipelineNoDeadlock:
"""Pipeline prefill must not deadlock at any rank count or prompt length."""
@pytest.mark.parametrize(
"layer_splits,prompt_tokens",
[
(LAYER_SPLITS_2WAY, 128),
(LAYER_SPLITS_2WAY, 4096),
(LAYER_SPLITS_2WAY, 8192),
(LAYER_SPLITS_2WAY, 16384),
(LAYER_SPLITS_4WAY, 128),
(LAYER_SPLITS_4WAY, 4096),
(LAYER_SPLITS_4WAY, 8192),
(LAYER_SPLITS_4WAY, 16384),
],
ids=[
"2rank_128tok",
"2rank_4096tok",
"2rank_8192tok",
"2rank_16384tok",
"4rank_128tok",
"4rank_4096tok",
"4rank_8192tok",
"4rank_16384tok",
],
)
def test_no_deadlock(
self,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
) -> None:
"""Pipeline must complete without deadlock at various prompt lengths."""
pipeline_results = _run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=prompt_tokens,
base_port=29650,
timeout=60,
)
# If we get here, no deadlock. Verify all ranks produced output.
for rank, pipe_data in sorted(pipeline_results.items()):
assert pipe_data["text"], f"Rank {rank} produced no output text"
class TestPipelinePrefillCallbacks:
"""Verify that pipeline prefill callbacks exactly match stream_generate callbacks."""
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500, 5000],
ids=["short_50", "medium_500", "long_5000"],
)
def test_callbacks_match(self, prompt_tokens: int) -> None:
"""Pipeline and stream_generate must produce identical callback sequences."""
# Run single-device (stream_generate path)
single = _run_single_device_test(prompt_tokens, timeout=180)
# Run 4-rank pipeline
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29700,
timeout=180,
)
single_callbacks = single["callbacks"]
prefill_count = single["prefill_token_count"]
# Every rank must produce the same callback sequence as stream_generate
for rank, pipe_data in sorted(pipeline_results.items()):
pipe_callbacks = pipe_data["callbacks"]
assert pipe_data["prefill_token_count"] == prefill_count, (
f"Rank {rank} prefill token count mismatch: "
f"{pipe_data['prefill_token_count']} vs {prefill_count}"
)
assert pipe_callbacks == single_callbacks, (
f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
f"(prefill M={prefill_count}):\n"
f" stream_generate ({len(single_callbacks)} callbacks): {single_callbacks}\n"
f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
)
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500],
ids=["short_50", "medium_500"],
)
def test_output_matches(self, prompt_tokens: int) -> None:
"""Pipeline-generated text must match single-device output."""
single = _run_single_device_test(prompt_tokens, timeout=180)
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29800,
timeout=180,
)
single_text = single["text"]
# The last rank produces the final logits, so its output should match.
# Due to SDPA tiling non-determinism, allow minor differences in text.
last_rank = max(pipeline_results.keys())
pipe_text = pipeline_results[last_rank]["text"]
# For deterministic sampling (temp=0.0), outputs should match exactly
# or be very close. Log both for debugging even if they match.
if single_text != pipe_text:
# Find first divergence point
min_len = min(len(single_text), len(pipe_text))
diverge_idx = next(
(i for i in range(min_len) if single_text[i] != pipe_text[i]),
min_len,
)
pytest.fail(
f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
f" single-device: {single_text!r}\n"
f" pipeline R{last_rank}: {pipe_text!r}"
)
class TestPipelineCallbacksStructure:
"""Verify structural properties of callbacks independent of model output."""
def test_callback_structure_matches_generate_step(self) -> None:
"""Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M)."""
prompt_tokens = 200
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29900,
timeout=180,
)
for rank, pipe_data in sorted(pipeline_results.items()):
callbacks = pipe_data["callbacks"]
m = pipe_data["prefill_token_count"]
assert m > 0, f"Rank {rank}: prefill token count is 0"
assert callbacks[0] == (0, m), (
f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
)
assert callbacks[-1] == (m, m), (
f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
)
if len(callbacks) > 2:
second_to_last = callbacks[-2]
assert second_to_last[0] < m, (
f"Rank {rank}: second-to-last callback should report < {m}, "
f"got {second_to_last}"
)
# All callbacks must have total == M
for i, (_, total) in enumerate(callbacks):
assert total == m, (
f"Rank {rank}: callback {i} has total={total}, expected {m}"
)
# processed values must be non-decreasing
processed_vals = [p for p, _ in callbacks]
for i in range(1, len(processed_vals)):
assert processed_vals[i] >= processed_vals[i - 1], (
f"Rank {rank}: callbacks not non-decreasing at index {i}: "
f"{processed_vals}"
)
# No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)
for i in range(1, len(callbacks)):
assert callbacks[i] != callbacks[i - 1], (
f"Rank {rank}: duplicate consecutive callback at index {i}: "
f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
)

View File

@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (