Compare commits

...

3 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
cdf721e6ad Refactor runner for implementing batching 2026-02-27 16:45:39 +00:00
rltakashige
152a27ea5d Fix pipeline mismatched send after 1587 (#1629)
## Motivation

Tests caught a bug. It was a real bug.
2026-02-26 16:48:34 +00:00
rltakashige
db36bd5ac6 Add custom prefill for pipeline (#1587)
## Motivation

Since we need to do distributed communication between prefill steps, the
out-of-the-box stream_generate that we currently use prevents pipeline
parallel models from overlapping computation across stages. While this
was technically a regression, the per-step communication is necessary for
cancellation, and we will need various kinds of distributed communication
in the future (e.g. for coordinating batching).

500 lines are for one testing file, so the diffs aren't as bad as they
look!
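
Roughly, the reason per-step communication matters is that every rank has to agree on cancellation between chunks; a minimal sketch (my own illustration, not code from this PR, with a hypothetical `is_cancelled()` flag and the standard `mx.distributed.all_sum` collective) of that per-chunk agreement:

```python
import mlx.core as mx

def prefill_with_cancel_checks(model, prompt, cache, step=4096, group=None,
                               is_cancelled=lambda: False):
    """Sketch: prefill chunk-by-chunk, agreeing on cancellation after each chunk."""
    for start in range(0, prompt.shape[0], step):
        model(prompt[start:start + step][None], cache=cache)
        # Reduce the local cancel flag across all ranks so nobody bails out alone,
        # which could leave the other ranks waiting on a send/recv that never comes.
        flag = mx.array([1.0 if is_cancelled() else 0.0])
        if group is not None:
            flag = mx.distributed.all_sum(flag, group=group)
        if flag.item() > 0:
            raise RuntimeError("prefill cancelled on some rank")
```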

## Changes

- Added a special prefill function for pipeline parallel models
- Edited the model to handle 
- Added a test to verify this new prefill and the original prefill produce
  identical results
- Improved type stubs to remove some `type: ignore`s

## Why It Works
<img width="768" height="1246" alt="image"
src="https://github.com/user-attachments/assets/8986ff17-ac23-4a02-9bd7-e6253a0ca799"
/>
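
For reference, a tiny Python sketch (my own illustration, not code from this diff) of the staggered schedule described in the `pipeline_parallel_prefill` docstring further down: each rank pads its real chunks with leading/trailing dummy iterations so every rank runs the same number of steps and the pipeline sends stay lined up.

```python
def schedule(rank: int, world_size: int, n_chunks: int) -> list[str]:
    # rank r: r leading dummies, then real chunks, then world_size - 1 - r trailing dummies
    leading = ["dummy"] * rank
    real = [f"real[{i}]" for i in range(n_chunks)]
    trailing = ["dummy"] * (world_size - 1 - rank)
    return leading + real + trailing

for r in range(2):
    print(f"R{r}:", schedule(r, world_size=2, n_chunks=3))
# R0: ['real[0]', 'real[1]', 'real[2]', 'dummy']
# R1: ['dummy', 'real[0]', 'real[1]', 'real[2]']
```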

## Test Plan

### Manual Testing
Needs more testing, but seems good so far.

### Automated Testing
Passes CI; benchmarks show a considerable prefill speedup (up to 1.98x).

Before:
<img width="3280" height="1238" alt="image"
src="https://github.com/user-attachments/assets/9abc1cbc-ecdb-4e48-a675-2c4cb04a32a0"
/>


After:
<img width="3344" height="1236" alt="image"
src="https://github.com/user-attachments/assets/e03c7987-41b4-4950-9ac3-2840e774ce30"
/>
2026-02-26 16:00:38 +00:00
26 changed files with 1785 additions and 771 deletions

View File

@@ -73,9 +73,11 @@ class GenerationResponse:
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
prompt_cache: Any,
quantized_kv_start: int | None,
kv_group_size: int | None,
kv_bits: int | None,
) -> None: ...
def generate_step(
prompt: mx.array,
model: nn.Module,

View File

@@ -16,7 +16,7 @@ class Cache(Protocol):
self, keys: mx.array, values: mx.array
) -> tuple[mx.array, mx.array]: ...
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
values: mx.array
offset: int
@property
def state(self) -> tuple[mx.array, mx.array]: ...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
@property
def meta_state(self) -> Literal[""]: ...
@meta_state.setter
def meta_state(self, v) -> None: ...
def trim(self, n: int) -> int: ...
def is_trimmable(self) -> Literal[False]: ...
@classmethod
def from_state(cls, state, meta_state) -> Self: ...
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
@property
def state(
self,
) -> tuple[array, array]: ...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
@state.setter
def state(self, v): # -> None:
...
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
def trim(self, n: int) -> int: ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@@ -253,10 +243,9 @@ class CacheList(_BaseCache):
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
def trim(self, n: int) -> int: ...
@property
def state(self): # -> list[Any]:
...
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
@state.setter
def state(self, v): # -> None:
...

View File

@@ -314,9 +314,13 @@ async def fetch_file_list_with_cache(
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
logger.opt(exception=e).warning(
"Ran into exception when fetching file list from HF."
)
if await aios.path.exists(cache_file):
logger.warning(
f"No internet and no cached file list for {model_id} - using local file list"
f"No cached file list for {model_id} - using local file list"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())

View File

@@ -258,6 +258,6 @@ def get_node_id_keypair(
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
keypair = Keypair.generate()
f.write(keypair.to_bytes())
return keypair

View File

@@ -2,6 +2,8 @@
from collections.abc import Sequence
from mlx import core as mx
from mlx import nn as nn
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: KVCacheType | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -1,5 +1,6 @@
import contextlib
import multiprocessing as mp
from collections.abc import Generator
from dataclasses import dataclass, field
from math import inf
from multiprocessing.synchronize import Event
@@ -282,6 +283,54 @@ class MpReceiver[T]:
return d
class NonBlockingGenerator[T](Generator[T | None, None, None]):
def __init__(self, source: MpReceiver[T] | Generator[T | None, None, None]) -> None:
self._receiver: MpReceiver[T] | None = None
self._inner: Generator[T | None, None, None] | None = None
if isinstance(source, MpReceiver):
self._receiver = source
else:
self._inner = source
self._exhausted = False
def send(self, value: None, /) -> T | None:
if self._exhausted:
raise StopIteration
if self._inner is not None:
try:
return next(self._inner)
except (StopIteration, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
assert self._receiver is not None
try:
return self._receiver.receive_nowait()
except WouldBlock:
return None
except (EndOfStream, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
def throw(
self,
typ: type[BaseException] | BaseException,
val: BaseException | object = None,
tb: TracebackType | None = None,
/,
) -> T | None:
raise StopIteration
@property
def is_exhausted(self) -> bool:
return self._exhausted
def try_receive(self) -> T | None:
try:
return next(self)
except StopIteration:
return None
class channel[T]: # noqa: N801
"""Create a pair of asynchronous channels for communicating within the same process"""

View File

@@ -1,17 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...

View File

@@ -49,6 +49,21 @@ TimeoutCallback = Callable[[], None]
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
def flush_prefill_sends() -> None:
for output, dst, group in _pending_prefill_sends:
sent = mx.distributed.send(output, dst, group=group)
mx.async_eval(sent)
_pending_prefill_sends.clear()
def clear_prefill_sends() -> None:
# Discard pending sends (e.g. on cancellation).
_pending_prefill_sends.clear()
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
@@ -150,6 +165,7 @@ class PipelineLastLayer(CustomMlxLayer):
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
self.is_prefill: bool = False
self.queue_sends: bool = False
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
@@ -163,9 +179,14 @@ class PipelineLastLayer(CustomMlxLayer):
mx.eval(output)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if self.queue_sends:
_pending_prefill_sends.append(
(output, (self.r + 1) % self.s, self.group)
)
else:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
@@ -190,6 +211,12 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
layer.is_prefill = is_prefill
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
for layer in model.layers: # type: ignore
if isinstance(layer, PipelineLastLayer):
layer.queue_sends = queue_sends
def get_inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
if isinstance(inner, nn.Module):

View File

@@ -13,8 +13,7 @@ from mlx_lm.models.cache import (
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.shared.types.mlx import KVCacheType, Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
@@ -254,9 +253,9 @@ def trim_cache(
if snapshot is not None and snapshot.states[i] is not None:
cache[i] = deepcopy(snapshot.states[i]) # type: ignore
else:
c.state = [None] * len(c.state) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
c.state = [None] * len(c.state)
else:
c.trim(num_tokens) # pyright: ignore[reportUnknownMemberType]
c.trim(num_tokens)
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:

View File

@@ -1,10 +1,14 @@
import functools
import math
import time
from copy import deepcopy
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.generate import (
maybe_quantize_kv_cache,
stream_generate,
)
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -19,13 +23,19 @@ from exo.shared.types.api import (
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.mlx import KVCacheType, Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
clear_prefill_sends,
flush_prefill_sends,
set_pipeline_prefill,
set_pipeline_queue_sends,
)
from exo.worker.engines.mlx.cache import (
CacheSnapshot,
KVPrefixCache,
@@ -56,6 +66,130 @@ class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def _has_pipeline_communication_layer(model: Model):
for layer in model.layers:
if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):
return True
return False
def pipeline_parallel_prefill(
model: Model,
prompt: mx.array,
prompt_cache: KVCacheType,
prefill_step_size: int,
kv_group_size: int | None,
kv_bits: int | None,
prompt_progress_callback: Callable[[int, int], None],
distributed_prompt_progress_callback: Callable[[], None] | None,
group: mx.distributed.Group,
) -> None:
"""Prefill the KV cache for pipeline parallel with overlapping stages.
Each rank processes the full prompt through its real cache, offset by leading
and trailing dummy iterations.
Total iterations per rank = N_real_chunks + world_size - 1:
- rank r leading dummies (skip_pipeline_io, throwaway cache)
- N_real_chunks real (pipeline IO active, real cache)
- (world_size-1-r) trailing dummies (skip_pipeline_io, throwaway cache)
e.g.
Timeline (2 ranks, 3 chunks of 10240 tokens @ step=4096):
iter 0: R0 real[0:4096] R1 dummy
iter 1: R0 real[4096:8192] R1 real[0:4096]
iter 2: R0 real[8192:10240] R1 real[4096:8192]
iter 3: R0 dummy R1 real[8192:10240]
This function is designed to match mlx_lm's stream_generate exactly in terms of
side effects (given the same prefill step size)
"""
prefill_step_size = prefill_step_size // min(4, group.size())
quantize_cache_fn: Callable[..., None] = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=0,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
_prompt_cache: KVCacheType = prompt_cache
rank = group.rank()
world_size = group.size()
# Build list of real prompt chunk sizes
total = len(prompt)
real_chunk_sizes: list[int] = []
remaining = total - 1
while remaining:
n = min(prefill_step_size, remaining)
real_chunk_sizes.append(n)
remaining -= n
n_real = len(real_chunk_sizes)
# Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]
n_leading = rank
n_trailing = world_size - 1 - rank
n_total = n_leading + n_real + n_trailing
t_start = time.perf_counter()
processed = 0
logger.info(
f"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations"
)
clear_prefill_sends()
# Initial callback matching generate_step
prompt_progress_callback(0, total)
try:
with mx.stream(generation_stream):
for _ in range(n_leading):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
for i in range(n_real):
chunk_size = real_chunk_sizes[i]
model(
prompt[processed : processed + chunk_size][None],
cache=_prompt_cache,
)
quantize_cache_fn(_prompt_cache)
processed += chunk_size
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
flush_prefill_sends()
prompt_progress_callback(processed, total)
for _ in range(n_trailing):
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
finally:
clear_prefill_sends()
# Post-loop: process remaining 1 token + add +1 entry to match stream_generate.
for _ in range(2):
with mx.stream(generation_stream):
model(prompt[-1:][None], cache=_prompt_cache)
quantize_cache_fn(_prompt_cache)
flush_prefill_sends()
assert _prompt_cache is not None
mx.eval([c.state for c in _prompt_cache]) # type: ignore
# Final callback matching generate_step
prompt_progress_callback(total, total)
logger.info(
f"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, "
f"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms"
)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
@@ -64,6 +198,7 @@ def prefill(
cache: KVCacheType,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
distributed_prompt_progress_callback: Callable[[], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -95,31 +230,57 @@ def prefill(
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
def combined_progress_callback(processed: int, total: int) -> None:
if distributed_prompt_progress_callback is not None:
distributed_prompt_progress_callback()
progress_callback(processed, total)
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
is_pipeline = _has_pipeline_communication_layer(model)
prefill_step_size = 4096
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
if is_pipeline and num_tokens >= prefill_step_size:
set_pipeline_queue_sends(model, queue_sends=True)
assert group is not None, "Pipeline prefill requires a distributed group"
pipeline_parallel_prefill(
model=model,
prompt=prompt_tokens,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
group=group,
)
else:
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=prefill_step_size,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=combined_progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
raise
set_pipeline_queue_sends(model, queue_sends=False)
set_pipeline_prefill(model, is_prefill=False)
# stream_generate added 1 extra generated token to the cache, so we should trim it.
@@ -132,7 +293,7 @@ def prefill(
cache[i] = deepcopy(pre_gen.states[i]) # type: ignore
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
c.trim(2)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -275,6 +436,8 @@ def mlx_generate(
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
on_generation_token: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -336,6 +499,7 @@ def mlx_generate(
caches,
group,
on_prefill_progress,
distributed_prompt_progress_callback,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
@@ -481,6 +645,9 @@ def mlx_generate(
full_prompt_tokens, caches, cache_snapshots
)
if on_generation_token is not None:
on_generation_token()
yield GenerationResponse(
text=text,
token=out.token,

View File

@@ -40,6 +40,7 @@ from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -52,7 +53,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
LayerLoadedCallback,
TimeoutCallback,

View File

@@ -297,10 +297,10 @@ def _pending_tasks(
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
if task.task_id in runner.completed or task.task_id in runner.pending:
continue
if isinstance(runner.status, RunnerReady) and all(
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -33,10 +33,15 @@ def entrypoint(
try:
if bound_instance.is_image_model:
from exo.worker.runner.image_models.runner import main
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
else:
from exo.worker.runner.llm_inference.runner import Runner
runner = Runner(
bound_instance, event_sender, task_receiver, cancel_receiver
)
runner.main()
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -0,0 +1,178 @@
from collections import deque
from collections.abc import Generator
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import PrefillCancelled, mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_any,
)
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
"""Check for debug prompt triggers in the input."""
import time
from exo.worker.engines.mlx.utils_mlx import mlx_force_oom
if len(task_params.input) == 0:
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
raise Exception("Artificial runner exception - for testing purposes only.")
if EXO_RUNNER_MUST_OOM in prompt:
mlx_force_oom()
if EXO_RUNNER_MUST_TIMEOUT in prompt:
time.sleep(100)
@dataclass(eq=False)
class BatchGenerator:
model: Model
tokenizer: TokenizerWrapper
group: mx.distributed.Group | None
kv_prefix_cache: KVPrefixCache | None
model_id: ModelId
device_rank: int
cancel_receiver: MpReceiver[TaskId]
cancelled_tasks: set[TaskId]
event_sender: MpSender[Event]
check_for_cancel_every: int
_queue: deque[tuple[TextGeneration, MpSender[GenerationResponse]]] = field(
default_factory=deque, init=False
)
_active: (
tuple[
TextGeneration,
MpSender[GenerationResponse],
Generator[GenerationResponse],
]
| None
) = field(default=None, init=False)
def submit(
self,
task: TextGeneration,
sender: MpSender[GenerationResponse],
) -> None:
self._queue.append((task, sender))
if self._active is None:
self._start_next()
def step(self) -> None:
if self._active is None:
if self._queue:
self._start_next()
else:
return
if self._active is None:
return
task, sender, gen = self._active
try:
response = next(gen)
sender.send(response)
except (StopIteration, PrefillCancelled):
sender.close()
self._active = None
if self._queue:
self._start_next()
except Exception as e:
self._send_error(task, e)
sender.close()
self._active = None
raise
def _start_next(self) -> None:
task, sender = self._queue.popleft()
try:
gen = self._build_generator(task)
except Exception as e:
self._send_error(task, e)
sender.close()
raise
self._active = (task, sender, gen)
def _send_error(self, task: TextGeneration, e: Exception) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
_check_for_debug_prompts(task.task_params)
prompt = apply_chat_template(self.tokenizer, task.task_params)
def on_prefill_progress(processed: int, total: int) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=PrefillProgressChunk(
model=self.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
def distributed_prompt_progress_callback() -> None:
self.cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self.cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
tokens_since_cancel_check = self.check_for_cancel_every
def on_generation_token() -> None:
nonlocal tokens_since_cancel_check
tokens_since_cancel_check += 1
if tokens_since_cancel_check >= self.check_for_cancel_every:
tokens_since_cancel_check = 0
self.cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self.cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
return mlx_generate(
model=self.model,
tokenizer=self.tokenizer,
task=task.task_params,
prompt=prompt,
kv_prefix_cache=self.kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
on_generation_token=on_generation_token,
group=self.group,
)

View File

@@ -0,0 +1,341 @@
from collections.abc import Generator
from functools import cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.llm_inference.tool_parsers import ToolParser
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
)
tool_arg_parts = []
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
if ch != "analysis" and thinking:
thinking = False
if delta:
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None:
yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
if response is None:
yield None
continue
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
def parse_thinking_models(
responses: Generator[GenerationResponse | None],
tokenizer: TokenizerWrapper,
starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
"""Route thinking tokens via is_thinking flag.
Swallows think tag tokens, sets is_thinking on all others.
Always yields tokens with finish_reason to avoid hanging the chunk stream.
"""
in_thinking = starts_in_thinking
for response in responses:
if response is None:
yield None
continue
if isinstance(response, ToolCallResponse):
yield response
continue
is_think_tag = (
tokenizer.think_end is not None and response.text == tokenizer.think_end
) or (
tokenizer.think_start is not None and response.text == tokenizer.think_start
)
if is_think_tag:
in_thinking = response.text != tokenizer.think_end
# Never swallow finish_reason — the chunk stream needs it to terminate.
if response.finish_reason is not None:
yield response.model_copy(update={"text": "", "is_thinking": False})
continue
yield response.model_copy(update={"is_thinking": in_thinking})
def parse_tool_calls(
responses: Generator[GenerationResponse | None], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse | None]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
else:
# fallthrough
yield response

View File

File diff suppressed because it is too large

View File

@@ -172,7 +172,7 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
self.pending[event.task_id].set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -190,6 +190,7 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
self.pending.pop(event.task_id, None)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -14,9 +14,9 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load

View File

@@ -9,8 +9,8 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
@@ -143,7 +143,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
# Cache should now hold the prompt tokens minus one
@@ -164,7 +171,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -200,7 +214,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), short_tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
short_tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -245,7 +266,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -285,7 +313,14 @@ class TestKVPrefixCacheWithModel:
cache = make_kv_cache(model)
_, _, snapshots = prefill(
model, tokenizer, make_sampler(0.0), tokens, cache, group=None
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache = KVPrefixCache(None)
@@ -513,7 +548,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
@@ -538,7 +582,16 @@ class TestKVPrefixCacheWithModel:
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache, group=None)
prefill(
model,
tokenizer,
make_sampler(0.0),
tokens,
cache,
group=None,
on_prefill_progress=None,
distributed_prompt_progress_callback=None,
)
kv_prefix_cache.add_kv_cache(tokens, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)

View File

@@ -0,0 +1,512 @@
# type: ignore
"""Test that pipeline prefill callbacks and output exactly match stream_generate.
Spins up a single-device (non-pipeline) run and a distributed pipeline run,
then verifies that the prompt_progress_callback sequences are identical
and that generated text matches.
"""
import json
import multiprocessing as mp
import os
import tempfile
import traceback
from typing import Any, cast
import pytest
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
MODEL_PATH = EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8"
TOTAL_LAYERS = 24
MAX_TOKENS = 10
SEED = 42
TEMPERATURE = 0.0
def _model_card() -> ModelCard:
return ModelCard(
model_id=ModelId(MODEL_ID),
storage_size=Memory.from_gb(12),
n_layers=TOTAL_LAYERS,
hidden_size=2880,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
)
def _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:
"""Build a prompt with the given number of user-content tokens, return (chat_prompt, task)."""
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
base_text = "The quick brown fox jumps over the lazy dog. "
base_toks = tokenizer.encode(base_text)
repeats = (prompt_tokens // len(base_toks)) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
model=MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
max_output_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
seed=SEED,
)
prompt = apply_chat_template(tokenizer, task)
return prompt, task
# ---------------------------------------------------------------------------
# Single-device process: uses stream_generate path (no pipeline layers)
# ---------------------------------------------------------------------------
def _run_single_device(
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load full model without pipeline sharding, run mlx_generate, record callbacks."""
try:
import mlx.core as mx
from mlx_lm.utils import load_model
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
build_model_path,
get_tokenizer,
)
model_path = build_model_path(ModelId(MODEL_ID))
model, _ = load_model(model_path, lazy=True, strict=False)
mx.eval(model)
# Use PipelineShardMetadata just for get_tokenizer (needs model_card), but
# do NOT apply pipeline sharding — the model keeps all layers unwrapped.
dummy_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=TOTAL_LAYERS,
n_layers=TOTAL_LAYERS,
)
tokenizer = get_tokenizer(model_path, dummy_meta)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=None,
on_prefill_progress=on_progress,
):
generated_text += response.text
if response.finish_reason is not None:
break
# Also record the token count that prefill() received (prompt_tokens[:-1])
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Pipeline device process: uses _pipeline_prefill_cache path
# ---------------------------------------------------------------------------
def _run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
result_queue: Any,
) -> None:
"""Load model with pipeline sharding, run mlx_generate, record callbacks."""
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
try:
import mlx.core as mx
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.cache import encode_prompt
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
group = mx.distributed.init(backend="ring", strict=True)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=_model_card(),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=TOTAL_LAYERS,
)
model, tokenizer = shard_and_load(
shard_meta, group, on_timeout=None, on_layer_loaded=None
)
model = cast(Any, model)
prompt, task = _build_prompt(tokenizer, prompt_tokens)
callbacks: list[tuple[int, int]] = []
def on_progress(processed: int, total: int) -> None:
callbacks.append((processed, total))
def distributed_prompt_progress_callback(_group: Any = group) -> None:
from exo.worker.engines.mlx.utils_mlx import mx_any
mx_any(False, _group)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=None,
group=group,
on_prefill_progress=on_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
):
generated_text += response.text
if response.finish_reason is not None:
break
all_tokens = encode_prompt(tokenizer, prompt)
prefill_token_count = len(all_tokens) - 1
result_queue.put(
(
rank,
True,
{
"callbacks": callbacks,
"text": generated_text,
"prefill_token_count": prefill_token_count,
},
)
)
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _create_hostfile(world_size: int, base_port: int) -> str:
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
return f.name
def _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:
"""Run single-device (stream_generate) prefill and return results."""
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.terminate()
p.join(timeout=5)
pytest.fail("Single-device process timed out")
assert not result_queue.empty(), "Single-device process produced no result"
success, data = result_queue.get()
assert success, f"Single-device process failed:\n{data}"
return data
def _run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
base_port: int,
timeout: int = 120,
) -> dict[int, dict[str, Any]]:
"""Run pipeline prefill across ranks and return per-rank results."""
world_size = len(layer_splits)
hostfile_path = _create_hostfile(world_size, base_port)
ctx = mp.get_context("spawn")
result_queue: Any = ctx.Queue()
try:
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=_run_pipeline_device,
args=(
rank,
world_size,
hostfile_path,
layer_splits,
prompt_tokens,
result_queue,
),
)
p.start()
processes.append(p)
for p in processes:
p.join(timeout=timeout)
timed_out = any(p.is_alive() for p in processes)
for p in processes:
if p.is_alive():
p.terminate()
p.join(timeout=5)
assert not timed_out, "Pipeline processes timed out"
results: dict[int, dict[str, Any]] = {}
while not result_queue.empty():
rank, success, data = result_queue.get()
assert success, f"Pipeline rank {rank} failed:\n{data}"
results[rank] = data
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}"
)
return results
finally:
os.unlink(hostfile_path)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not MODEL_PATH.exists(),
reason=f"GPT-OSS model not found at {MODEL_PATH}",
),
]
LAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]
LAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]
class TestPipelineNoDeadlock:
"""Pipeline prefill must not deadlock at any rank count or prompt length."""
@pytest.mark.parametrize(
"layer_splits,prompt_tokens",
[
(LAYER_SPLITS_2WAY, 128),
(LAYER_SPLITS_2WAY, 4096),
(LAYER_SPLITS_2WAY, 8192),
(LAYER_SPLITS_2WAY, 16384),
(LAYER_SPLITS_4WAY, 128),
(LAYER_SPLITS_4WAY, 4096),
(LAYER_SPLITS_4WAY, 8192),
(LAYER_SPLITS_4WAY, 16384),
],
ids=[
"2rank_128tok",
"2rank_4096tok",
"2rank_8192tok",
"2rank_16384tok",
"4rank_128tok",
"4rank_4096tok",
"4rank_8192tok",
"4rank_16384tok",
],
)
def test_no_deadlock(
self,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
) -> None:
"""Pipeline must complete without deadlock at various prompt lengths."""
pipeline_results = _run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=prompt_tokens,
base_port=29650,
timeout=60,
)
# If we get here, no deadlock. Verify all ranks produced output.
for rank, pipe_data in sorted(pipeline_results.items()):
assert pipe_data["text"], f"Rank {rank} produced no output text"
class TestPipelinePrefillCallbacks:
"""Verify that pipeline prefill callbacks exactly match stream_generate callbacks."""
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500, 5000],
ids=["short_50", "medium_500", "long_5000"],
)
def test_callbacks_match(self, prompt_tokens: int) -> None:
"""All pipeline ranks must produce identical callback sequences."""
# Run 4-rank pipeline
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29700,
timeout=180,
)
# All ranks must agree on prefill token count and callback sequence
rank0_data = pipeline_results[0]
rank0_callbacks = rank0_data["callbacks"]
prefill_count = rank0_data["prefill_token_count"]
for rank, pipe_data in sorted(pipeline_results.items()):
pipe_callbacks = pipe_data["callbacks"]
assert pipe_data["prefill_token_count"] == prefill_count, (
f"Rank {rank} prefill token count mismatch: "
f"{pipe_data['prefill_token_count']} vs {prefill_count}"
)
assert pipe_callbacks == rank0_callbacks, (
f"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens "
f"(prefill M={prefill_count}):\n"
f" pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\n"
f" pipeline R{rank} ({len(pipe_callbacks)} callbacks): {pipe_callbacks}"
)
# Structural checks: starts with (0, M), ends with (M, M), monotonically increasing
assert rank0_callbacks[0] == (0, prefill_count), (
f"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}"
)
assert rank0_callbacks[-1] == (prefill_count, prefill_count), (
f"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}"
)
for i in range(1, len(rank0_callbacks)):
assert rank0_callbacks[i][0] >= rank0_callbacks[i - 1][0], (
f"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}"
)
@pytest.mark.parametrize(
"prompt_tokens",
[50, 500],
ids=["short_50", "medium_500"],
)
def test_output_matches(self, prompt_tokens: int) -> None:
"""Pipeline-generated text must match single-device output."""
single = _run_single_device_test(prompt_tokens, timeout=180)
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29800,
timeout=180,
)
single_text = single["text"]
# The last rank produces the final logits, so its output should match.
# Due to SDPA tiling non-determinism, allow minor differences in text.
last_rank = max(pipeline_results.keys())
pipe_text = pipeline_results[last_rank]["text"]
# For deterministic sampling (temp=0.0), outputs should match exactly
# or be very close. Log both for debugging even if they match.
if single_text != pipe_text:
# Find first divergence point
min_len = min(len(single_text), len(pipe_text))
diverge_idx = next(
(i for i in range(min_len) if single_text[i] != pipe_text[i]),
min_len,
)
pytest.fail(
f"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\n"
f" single-device: {single_text!r}\n"
f" pipeline R{last_rank}: {pipe_text!r}"
)
class TestPipelineCallbacksStructure:
"""Verify structural properties of callbacks independent of model output."""
def test_callback_structure_matches_generate_step(self) -> None:
"""Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M)."""
prompt_tokens = 200
pipeline_results = _run_pipeline_test(
layer_splits=LAYER_SPLITS_4WAY,
prompt_tokens=prompt_tokens,
base_port=29900,
timeout=180,
)
for rank, pipe_data in sorted(pipeline_results.items()):
callbacks = pipe_data["callbacks"]
m = pipe_data["prefill_token_count"]
assert m > 0, f"Rank {rank}: prefill token count is 0"
assert callbacks[0] == (0, m), (
f"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}"
)
assert callbacks[-1] == (m, m), (
f"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}"
)
if len(callbacks) > 2:
second_to_last = callbacks[-2]
assert second_to_last[0] < m, (
f"Rank {rank}: second-to-last callback should report < {m}, "
f"got {second_to_last}"
)
# All callbacks must have total == M
for i, (_, total) in enumerate(callbacks):
assert total == m, (
f"Rank {rank}: callback {i} has total={total}, expected {m}"
)
# processed values must be non-decreasing
processed_vals = [p for p, _ in callbacks]
for i in range(1, len(processed_vals)):
assert processed_vals[i] >= processed_vals[i - 1], (
f"Rank {rank}: callbacks not non-decreasing at index {i}: "
f"{processed_vals}"
)
# No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)
for i in range(1, len(callbacks)):
assert callbacks[i] != callbacks[i - 1], (
f"Rank {rank}: duplicate consecutive callback at index {i}: "
f"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)"
)

View File

@@ -15,8 +15,8 @@ from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import (

View File

@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────

View File

@@ -6,6 +6,7 @@ from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
import exo.worker.runner.llm_inference.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
@@ -115,17 +116,20 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(
mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
)
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
@@ -183,12 +187,13 @@ def _run(tasks: Iterable[Task]):
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
runner = mlx_runner.Runner(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
runner.main()
return event_sender.events

View File

@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(parse_gpt_oss(_gen()))
return list(x for x in parse_gpt_oss(_gen()) if x is not None)
def _get_tool_call(

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.llm_inference.runner import parse_tool_calls
from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser