From 0a549f8846dd49ad0347ef61313e893eed246f07 Mon Sep 17 00:00:00 2001 From: Evan Quiney Date: Wed, 22 Apr 2026 14:03:31 +0100 Subject: [PATCH] remove layer loading callback (#1890) first part of modularising the backend is simplifying some of the control flow. more tbd. --- .../shared/types/worker/runner_response.py | 5 + .../worker/engines/image/distributed_model.py | 5 +- src/exo/worker/engines/image/generate.py | 76 +++++++++--- src/exo/worker/engines/mlx/auto_parallel.py | 87 ++++++------- src/exo/worker/engines/mlx/utils_mlx.py | 41 +++---- src/exo/worker/runner/image_models/runner.py | 116 +++--------------- .../runner/llm_inference/batch_generator.py | 46 +++---- .../llm_inference/model_output_parsers.py | 65 ++++++++-- src/exo/worker/runner/llm_inference/runner.py | 101 ++++----------- .../test_runner/test_event_ordering.py | 20 ++- 10 files changed, 254 insertions(+), 308 deletions(-) diff --git a/src/exo/shared/types/worker/runner_response.py b/src/exo/shared/types/worker/runner_response.py index 27a7a3a0c..e415bdca3 100644 --- a/src/exo/shared/types/worker/runner_response.py +++ b/src/exo/shared/types/worker/runner_response.py @@ -70,6 +70,11 @@ class FinishedResponse(BaseRunnerResponse): pass +class ModelLoadingResponse(BaseRunnerResponse): + layers_loaded: int + total: int + + class PrefillProgressResponse(BaseRunnerResponse): processed_tokens: int total_tokens: int diff --git a/src/exo/worker/engines/image/distributed_model.py b/src/exo/worker/engines/image/distributed_model.py index 7d866a3a5..4c9e7406e 100644 --- a/src/exo/worker/engines/image/distributed_model.py +++ b/src/exo/worker/engines/image/distributed_model.py @@ -8,6 +8,7 @@ from PIL import Image from exo.api.types import AdvancedImageParams from exo.download.download_utils import build_model_path +from exo.shared.types.common import ModelId from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata from exo.worker.engines.image.config import ImageModelConfig @@ -22,13 +23,14 @@ from exo.worker.runner.bootstrap import logger class DistributedImageModel: + model_id: ModelId _config: ImageModelConfig _adapter: ModelAdapter[Any, Any] _runner: DiffusionRunner def __init__( self, - model_id: str, + model_id: ModelId, local_path: Path, shard_metadata: PipelineShardMetadata | CfgShardMetadata, group: Optional[mx.distributed.Group] = None, @@ -68,6 +70,7 @@ class DistributedImageModel: else: logger.info("Single-node initialization") + self.model_id = model_id self._config = config self._adapter = adapter self._runner = runner diff --git a/src/exo/worker/engines/image/generate.py b/src/exo/worker/engines/image/generate.py index fbc238156..d393d5e89 100644 --- a/src/exo/worker/engines/image/generate.py +++ b/src/exo/worker/engines/image/generate.py @@ -3,7 +3,7 @@ import io import random import tempfile import time -from collections.abc import Callable +from collections.abc import Callable, Iterator from pathlib import Path from typing import Generator, Literal @@ -17,11 +17,10 @@ from exo.api.types import ( ImageGenerationTaskParams, ImageSize, ) +from exo.shared.constants import EXO_MAX_CHUNK_SIZE +from exo.shared.types.chunks import ImageChunk +from exo.shared.types.common import ModelId from exo.shared.types.memory import Memory -from exo.shared.types.worker.runner_response import ( - ImageGenerationResponse, - PartialImageResponse, -) from exo.worker.engines.image.distributed_model import DistributedImageModel @@ -71,16 +70,8 @@ def generate_image( model: DistributedImageModel, task: ImageGenerationTaskParams | ImageEditsTaskParams, cancel_checker: Callable[[], bool] | None = None, -) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]: - """Generate image(s), optionally yielding partial results. - - When partial_images > 0 or stream=True, yields PartialImageResponse for - intermediate images, then ImageGenerationResponse for the final image. - - Yields: - PartialImageResponse for intermediate images (if partial_images > 0, first image only) - ImageGenerationResponse for final complete images - """ +) -> Generator[ImageChunk, None, None]: + """Generate image(s), optionally yielding partial results.""" width, height = parse_size(task.size) quality: Literal["low", "medium", "high"] = task.quality or "medium" @@ -142,12 +133,14 @@ def generate_image( image = image.convert("RGB") image.save(buffer, format=image_format) - yield PartialImageResponse( + yield from _process_image_response( image_data=buffer.getvalue(), - format=task.output_format, + image_format=task.output_format, partial_index=partial_idx, total_partials=total_partials, image_index=image_num, + model_id=model.model_id, + stats=None, ) else: image = result @@ -189,9 +182,54 @@ def generate_image( image = image.convert("RGB") image.save(buffer, format=image_format) - yield ImageGenerationResponse( + yield from _process_image_response( image_data=buffer.getvalue(), - format=task.output_format, + image_format=task.output_format, stats=stats, image_index=image_num, + model_id=model.model_id, + partial_index=None, + total_partials=None, ) + + +def _process_image_response( + image_data: bytes, + image_index: int, + image_format: Literal["png", "jpeg", "webp"], + partial_index: int | None, + total_partials: int | None, + stats: ImageGenerationStats | None, + model_id: ModelId, +) -> Iterator[ImageChunk]: + """Process a single image response and send chunks.""" + is_partial = partial_index is not None + encoded_data = base64.b64encode(image_data).decode("utf-8") + # Extract stats from final ImageGenerationResponse if available + data_chunks = [ + encoded_data[i : i + EXO_MAX_CHUNK_SIZE] + for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE) + ] + total_chunks = len(data_chunks) + + def _data_to_chunk(item: tuple[int, str]) -> ImageChunk: + chunk_index, chunk_data = item + # Only include stats on the last chunk of the final image + chunk_stats = ( + stats if chunk_index == total_chunks - 1 and not is_partial else None + ) + + return ImageChunk( + model=model_id, + data=chunk_data, + chunk_index=chunk_index, + total_chunks=total_chunks, + image_index=image_index, + is_partial=is_partial, + partial_index=partial_index, + total_partials=total_partials, + stats=chunk_stats, + format=image_format, + ) + + return map(_data_to_chunk, enumerate(data_chunks)) diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 64df97f2d..398595ee7 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Generator from functools import partial from inspect import signature from typing import TYPE_CHECKING, Literal, Protocol, cast @@ -59,14 +59,13 @@ from mlx_lm.models.step3p5 import Model as Step35Model from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel +from exo.shared.types.worker.runner_response import ModelLoadingResponse from exo.shared.types.worker.shards import PipelineShardMetadata from exo.worker.runner.bootstrap import logger if TYPE_CHECKING: from mlx_lm.models.cache import Cache -LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers) - _pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = [] @@ -276,8 +275,7 @@ def pipeline_auto_parallel( model: nn.Module, group: mx.distributed.Group, model_shard_meta: PipelineShardMetadata, - on_layer_loaded: LayerLoadedCallback | None, -) -> nn.Module: +) -> Generator[ModelLoadingResponse, None, nn.Module]: """ Automatically parallelize a model across multiple devices. Args: @@ -297,8 +295,7 @@ def pipeline_auto_parallel( total = len(layers) for i, layer in enumerate(layers): mx.eval(layer) # type: ignore - if on_layer_loaded is not None: - on_layer_loaded(i, total) + yield ModelLoadingResponse(layers_loaded=i, total=total) layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group) layers[-1] = PipelineLastLayer( @@ -460,8 +457,7 @@ def patch_tensor_model[T](model: T) -> T: def tensor_auto_parallel( model: nn.Module, group: mx.distributed.Group, - on_layer_loaded: LayerLoadedCallback | None, -) -> nn.Module: +) -> Generator[ModelLoadingResponse, None, nn.Module]: all_to_sharded_linear = partial( shard_linear, sharding="all-to-sharded", @@ -595,7 +591,7 @@ def tensor_auto_parallel( else: raise ValueError(f"Unsupported model type: {type(model)}") - model = tensor_parallel_sharding_strategy.shard_model(model, on_layer_loaded) + model = yield from tensor_parallel_sharding_strategy.shard_model(model) return patch_tensor_model(model) @@ -619,16 +615,14 @@ class TensorParallelShardingStrategy(ABC): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: ... + ) -> Generator[ModelLoadingResponse, None, nn.Module]: ... class LlamaShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(LlamaModel, model) total = len(model.layers) for i, layer in enumerate(model.layers): @@ -646,8 +640,8 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy): layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj) layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj) mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -681,8 +675,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(DeepseekV3Model, model) total = len(model.layers) @@ -738,8 +731,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): layer.mlp.sharding_group = self.group mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -764,8 +757,7 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(GLM4MoeLiteModel, model) total = len(model.layers) # type: ignore for i, layer in enumerate(model.layers): # type: ignore @@ -816,8 +808,8 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy): layer.mlp = ShardedMoE(layer.mlp) # type: ignore layer.mlp.sharding_group = self.group # type: ignore mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -904,8 +896,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(MiniMaxModel, model) total = len(model.layers) for i, layer in enumerate(model.layers): @@ -934,8 +925,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy): layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue] mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -943,8 +934,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast( Qwen3Model | Qwen3MoeModel @@ -1099,8 +1089,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy): layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj) mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -1108,8 +1098,7 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(Glm4MoeModel, model) total = len(model.layers) for i, layer in enumerate(model.layers): @@ -1145,8 +1134,8 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy): layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj) mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -1154,8 +1143,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(GptOssMoeModel, model) total = len(model.layers) @@ -1186,8 +1174,8 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy): layer.mlp = ShardedMoE(layer.mlp) # type: ignore layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue] mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -1195,8 +1183,7 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(Step35Model, model) total = len(model.layers) @@ -1229,8 +1216,8 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy): self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj) mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + + yield ModelLoadingResponse(layers_loaded=i, total=total) return model @@ -1238,8 +1225,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(NemotronHModel, model) rank = self.group.rank() total = len(model.layers) @@ -1272,8 +1258,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy): layer.mixer = mixer # pyright: ignore[reportAttributeAccessIssue] mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + yield ModelLoadingResponse(layers_loaded=i, total=total) return model def _shard_mamba2_mixer(self, mixer: NemotronHMamba2Mixer, rank: int) -> None: @@ -1380,8 +1365,7 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, - on_layer_loaded: LayerLoadedCallback | None, - ) -> nn.Module: + ) -> Generator[ModelLoadingResponse, None, nn.Module]: model = cast(Gemma4Model, model) layers = model.language_model.model.layers total = len(layers) @@ -1409,6 +1393,5 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy): layer.experts.sharding_group = self.group mx.eval(layer) - if on_layer_loaded is not None: - on_layer_loaded(i, total) + yield ModelLoadingResponse(layers_loaded=i, total=total) return model diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 7ad45a16f..2b28c7ff3 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -4,6 +4,7 @@ import re import sys import tempfile import time +from collections.abc import Generator from pathlib import Path from typing import TYPE_CHECKING, Any, cast @@ -51,6 +52,7 @@ from exo.shared.types.worker.instances import ( MlxJacclInstance, MlxRingInstance, ) +from exo.shared.types.worker.runner_response import ModelLoadingResponse from exo.shared.types.worker.shards import ( CfgShardMetadata, PipelineShardMetadata, @@ -58,7 +60,6 @@ from exo.shared.types.worker.shards import ( TensorShardMetadata, ) from exo.worker.engines.mlx.auto_parallel import ( - LayerLoadedCallback, get_inner_model, get_layers, pipeline_auto_parallel, @@ -66,8 +67,6 @@ from exo.worker.engines.mlx.auto_parallel import ( ) from exo.worker.runner.bootstrap import logger -Group = mx.distributed.Group - def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: return Memory.from_float_kb( @@ -90,7 +89,7 @@ class HostList(RootModel[list[str]]): def mlx_distributed_init( bound_instance: BoundInstance, -) -> Group: +) -> mx.distributed.Group: """ Initialize MLX distributed. """ @@ -149,7 +148,7 @@ def mlx_distributed_init( def initialize_mlx( bound_instance: BoundInstance, -) -> Group: +) -> mx.distributed.Group: # should we unseed it? # TODO: pass in seed from params mx.random.seed(42) @@ -162,9 +161,10 @@ def initialize_mlx( def load_mlx_items( bound_instance: BoundInstance, - group: Group | None, - on_layer_loaded: LayerLoadedCallback | None, -) -> "tuple[Model, TokenizerWrapper, VisionProcessor | None]": + group: mx.distributed.Group | None, +) -> Generator[ + ModelLoadingResponse, None, tuple[Model, TokenizerWrapper, "VisionProcessor | None"] +]: if group is None: logger.info(f"Single device used for {bound_instance.instance}") model_path = build_model_path(bound_instance.bound_shard.model_card.model_id) @@ -177,8 +177,7 @@ def load_mlx_items( total = len(layers) for i, layer in enumerate(layers): mx.eval(layer) # type: ignore - if on_layer_loaded is not None: - on_layer_loaded(i, total) + yield ModelLoadingResponse(layers_loaded=i, total=total) except ValueError as e: logger.opt(exception=e).debug( "Model architecture doesn't support layer-by-layer progress tracking", @@ -191,10 +190,9 @@ def load_mlx_items( else: logger.info("Starting distributed init") start_time = time.perf_counter() - model, tokenizer = shard_and_load( + model, tokenizer = yield from shard_and_load( bound_instance.bound_shard, group=group, - on_layer_loaded=on_layer_loaded, ) end_time = time.perf_counter() logger.info( @@ -221,9 +219,8 @@ def load_mlx_items( def shard_and_load( shard_metadata: ShardMetadata, - group: Group, - on_layer_loaded: LayerLoadedCallback | None, -) -> tuple[nn.Module, TokenizerWrapper]: + group: mx.distributed.Group, +) -> Generator[ModelLoadingResponse, None, tuple[nn.Module, TokenizerWrapper]]: model_path = build_model_path(shard_metadata.model_card.model_id) model, _ = load_model(model_path, lazy=True, strict=False) @@ -254,12 +251,10 @@ def shard_and_load( match shard_metadata: case TensorShardMetadata(): logger.info(f"loading model from {model_path} with tensor parallelism") - model = tensor_auto_parallel(model, group, on_layer_loaded) + model = yield from tensor_auto_parallel(model, group) case PipelineShardMetadata(): logger.info(f"loading model from {model_path} with pipeline parallelism") - model = pipeline_auto_parallel( - model, group, shard_metadata, on_layer_loaded=on_layer_loaded - ) + model = yield from pipeline_auto_parallel(model, group, shard_metadata) mx.eval(model.parameters()) case CfgShardMetadata(): raise ValueError( @@ -748,7 +743,9 @@ def set_wired_limit_for_model(model_size: Memory): def mlx_cleanup( - model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None + model: Model | None, + tokenizer: TokenizerWrapper | None, + group: mx.distributed.Group | None, ) -> None: del model, tokenizer, group mx.clear_cache() @@ -757,7 +754,7 @@ def mlx_cleanup( gc.collect() -def mx_any(bool_: bool, group: Group | None) -> bool: +def mx_any(bool_: bool, group: mx.distributed.Group | None) -> bool: if group is None: return bool_ num_true = mx.distributed.all_sum( @@ -767,7 +764,7 @@ def mx_any(bool_: bool, group: Group | None) -> bool: return num_true.item() > 0 -def mx_barrier(group: Group | None): +def mx_barrier(group: mx.distributed.Group | None): if group is None: return mx.eval( diff --git a/src/exo/worker/runner/image_models/runner.py b/src/exo/worker/runner/image_models/runner.py index 2eb90baec..e743c4ce4 100644 --- a/src/exo/worker/runner/image_models/runner.py +++ b/src/exo/worker/runner/image_models/runner.py @@ -1,19 +1,17 @@ -import base64 import time -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import mlx.core as mx from exo.api.types import ( ImageEditsTaskParams, - ImageGenerationStats, ImageGenerationTaskParams, ) -from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED +from exo.shared.constants import EXO_TRACING_ENABLED from exo.shared.models.model_cards import ModelTask from exo.shared.tracing import clear_trace_buffer, get_trace_buffer -from exo.shared.types.chunks import ErrorChunk, ImageChunk -from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.chunks import ErrorChunk +from exo.shared.types.common import CommandId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -36,10 +34,6 @@ from exo.shared.types.tasks import ( TaskStatus, ) from exo.shared.types.worker.instances import BoundInstance -from exo.shared.types.worker.runner_response import ( - ImageGenerationResponse, - PartialImageResponse, -) from exo.shared.types.worker.runners import ( RunnerConnected, RunnerConnecting, @@ -87,32 +81,6 @@ def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool: return False -def _process_image_response( - response: ImageGenerationResponse | PartialImageResponse, - command_id: CommandId, - shard_metadata: ShardMetadata, - event_sender: MpSender[Event], - image_index: int, -) -> None: - """Process a single image response and send chunks.""" - encoded_data = base64.b64encode(response.image_data).decode("utf-8") - is_partial = isinstance(response, PartialImageResponse) - # Extract stats from final ImageGenerationResponse if available - stats = response.stats if isinstance(response, ImageGenerationResponse) else None - _send_image_chunk( - encoded_data=encoded_data, - command_id=command_id, - model_id=shard_metadata.model_card.model_id, - event_sender=event_sender, - image_index=response.image_index, - is_partial=is_partial, - partial_index=response.partial_index if is_partial else None, - total_partials=response.total_partials if is_partial else None, - stats=stats, - image_format=response.format, - ) - - def _send_traces_if_enabled( event_sender: MpSender[Event], task_id: TaskId, @@ -143,48 +111,6 @@ def _send_traces_if_enabled( clear_trace_buffer() -def _send_image_chunk( - encoded_data: str, - command_id: CommandId, - model_id: ModelId, - event_sender: MpSender[Event], - image_index: int, - is_partial: bool, - partial_index: int | None = None, - total_partials: int | None = None, - stats: ImageGenerationStats | None = None, - image_format: Literal["png", "jpeg", "webp"] | None = None, -) -> None: - """Send base64-encoded image data as chunks via events.""" - data_chunks = [ - encoded_data[i : i + EXO_MAX_CHUNK_SIZE] - for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE) - ] - total_chunks = len(data_chunks) - for chunk_index, chunk_data in enumerate(data_chunks): - # Only include stats on the last chunk of the final image - chunk_stats = ( - stats if chunk_index == total_chunks - 1 and not is_partial else None - ) - event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ImageChunk( - model=model_id, - data=chunk_data, - chunk_index=chunk_index, - total_chunks=total_chunks, - image_index=image_index, - is_partial=is_partial, - partial_index=partial_index, - total_partials=total_partials, - stats=chunk_stats, - format=image_format, - ), - ) - ) - - class Runner: def __init__( self, @@ -261,35 +187,21 @@ class Runner: return self._check_cancelled(task.task_id) try: - image_index = 0 - for response in generate_image( + for chunk in generate_image( model=self.image_model, task=task_params, cancel_checker=cancel_checker, ): if _is_primary_output_node(self.shard_metadata): - match response: - case PartialImageResponse(): - logger.info( - f"sending partial ImageChunk {response.partial_index}/{response.total_partials}" - ) - _process_image_response( - response, - command_id, - self.shard_metadata, - self.event_sender, - image_index, - ) - case ImageGenerationResponse(): - logger.info("sending final ImageChunk") - _process_image_response( - response, - command_id, - self.shard_metadata, - self.event_sender, - image_index, - ) - image_index += 1 + if chunk.is_partial: + logger.info( + f"sending partial ImageChunk {chunk.partial_index}/{chunk.total_partials}" + ) + else: + logger.info("sending final ImageChunk") + self.event_sender.send( + ChunkGenerated(command_id=command_id, chunk=chunk) + ) except Exception as e: if _is_primary_output_node(self.shard_metadata): self.event_sender.send( diff --git a/src/exo/worker/runner/llm_inference/batch_generator.py b/src/exo/worker/runner/llm_inference/batch_generator.py index 5284d53ae..ee536a993 100644 --- a/src/exo/worker/runner/llm_inference/batch_generator.py +++ b/src/exo/worker/runner/llm_inference/batch_generator.py @@ -2,20 +2,20 @@ import itertools import time from abc import ABC, abstractmethod from collections import deque -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterator from dataclasses import dataclass, field import mlx.core as mx from mlx_lm.tokenizer_utils import TokenizerWrapper from exo.shared.constants import EXO_MAX_CONCURRENT_REQUESTS -from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk +from exo.shared.types.chunks import ErrorChunk, GenerationChunk, PrefillProgressChunk from exo.shared.types.common import ModelId from exo.shared.types.events import ChunkGenerated, Event from exo.shared.types.mlx import Model from exo.shared.types.tasks import CANCEL_ALL_TASKS, TaskId, TextGeneration from exo.shared.types.text_generation import TextGenerationTaskParams -from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse +from exo.shared.types.worker.runner_response import GenerationResponse from exo.utils.channels import MpReceiver, MpSender from exo.worker.engines.mlx.cache import KVPrefixCache from exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator @@ -32,7 +32,7 @@ from exo.worker.engines.mlx.utils_mlx import ( from exo.worker.engines.mlx.vision import VisionProcessor from exo.worker.runner.bootstrap import logger -from .model_output_parsers import apply_all_parsers +from .model_output_parsers import apply_all_parsers, map_responses_to_chunks from .tool_parsers import ToolParser @@ -80,9 +80,7 @@ class InferenceGenerator(ABC): @abstractmethod def step( self, - ) -> Iterable[ - tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished] - ]: ... + ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: ... @abstractmethod def close(self) -> None: ... @@ -137,7 +135,7 @@ class SequentialGenerator(InferenceGenerator): # queue that the 1st generator should push to and 3rd generator should pull from GeneratorQueue[GenerationResponse], # generator to get parsed outputs - Generator[GenerationResponse | ToolCallResponse | None], + Iterator[GenerationChunk | None], ] | None ) = field(default=None, init=False) @@ -183,9 +181,7 @@ class SequentialGenerator(InferenceGenerator): def step( self, - ) -> Iterable[ - tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished] - ]: + ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: if self._active is None: self.agree_on_tasks() @@ -197,9 +193,7 @@ class SequentialGenerator(InferenceGenerator): assert self._active is not None task, mlx_gen, queue, output_generator = self._active - output: list[ - tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished] - ] = [] + output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = [] try: response = next(mlx_gen) queue.push(response) @@ -233,7 +227,9 @@ class SequentialGenerator(InferenceGenerator): queue = GeneratorQueue[GenerationResponse]() if task.task_params.bench: - output_generator = queue.gen() + output_generator: Iterator[GenerationChunk | None] = map( + lambda r: map_responses_to_chunks(r, self.model_id), queue.gen() + ) else: output_generator = apply_all_parsers( queue.gen(), @@ -338,7 +334,7 @@ class BatchGenerator(InferenceGenerator): tuple[ TextGeneration, GeneratorQueue[GenerationResponse], - Generator[GenerationResponse | ToolCallResponse | None], + Iterator[GenerationChunk | None], ], ] = field(default_factory=dict, init=False) @@ -392,9 +388,7 @@ class BatchGenerator(InferenceGenerator): def step( self, - ) -> Iterable[ - tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished] - ]: + ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: if not self._queue: self.agree_on_tasks() @@ -411,7 +405,9 @@ class BatchGenerator(InferenceGenerator): queue = GeneratorQueue[GenerationResponse]() if task.task_params.bench: - output_generator = queue.gen() + output_generator: Iterator[GenerationChunk | None] = map( + lambda r: map_responses_to_chunks(r, self.model_id), queue.gen() + ) else: output_generator = apply_all_parsers( queue.gen(), @@ -429,9 +425,7 @@ class BatchGenerator(InferenceGenerator): results = self._mlx_gen.step() - output: list[ - tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished] - ] = [] + output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = [] for uid, response in results: if uid not in self._active_tasks: # should we error here? @@ -453,9 +447,9 @@ class BatchGenerator(InferenceGenerator): def _apply_cancellations( self, - ) -> list[tuple[TaskId, Cancelled]]: + ) -> Iterator[tuple[TaskId, Cancelled]]: if not self._cancelled_tasks: - return [] + return iter([]) cancel_all = CANCEL_ALL_TASKS in self._cancelled_tasks @@ -477,7 +471,7 @@ class BatchGenerator(InferenceGenerator): results.append((tid, Cancelled())) self._cancelled_tasks.clear() - return results + return iter(results) def _send_error(self, task: TextGeneration, e: Exception) -> None: if self.device_rank == 0: diff --git a/src/exo/worker/runner/llm_inference/model_output_parsers.py b/src/exo/worker/runner/llm_inference/model_output_parsers.py index 27a24db02..906a26558 100644 --- a/src/exo/worker/runner/llm_inference/model_output_parsers.py +++ b/src/exo/worker/runner/llm_inference/model_output_parsers.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Iterator from functools import cache from typing import Any @@ -14,6 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs] ) from exo.api.types import ToolCallItem +from exo.shared.types.chunks import ( + ErrorChunk, + GenerationChunk, + TokenChunk, + ToolCallChunk, +) from exo.shared.types.common import ModelId from exo.shared.types.mlx import Model from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse @@ -64,29 +70,70 @@ def apply_all_parsers( model_type: type[Model], model_id: ModelId, tools: list[dict[str, Any]] | None, -) -> Generator[GenerationResponse | ToolCallResponse | None]: - mlx_generator = receiver +) -> Iterator[GenerationChunk | None]: + generator = receiver if issubclass(model_type, GptOssModel): - mlx_generator = parse_gpt_oss(mlx_generator) + generator = parse_gpt_oss(generator) elif ( issubclass(model_type, DeepseekV32Model) and "deepseek" in model_id.normalize().lower() ): - mlx_generator = parse_deepseek_v32(mlx_generator) + generator = parse_deepseek_v32(generator) else: if tokenizer.has_thinking: - mlx_generator = parse_thinking_models( - mlx_generator, + generator = parse_thinking_models( + generator, tokenizer.think_start, tokenizer.think_end, starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer), ) if tool_parser: - mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools) + generator = parse_tool_calls(generator, tool_parser, tools) - return count_reasoning_tokens(mlx_generator) + generator = count_reasoning_tokens(generator) + + return map(lambda r: map_responses_to_chunks(r, model_id), generator) + + +def map_responses_to_chunks( + response: GenerationResponse | ToolCallResponse | None, model_id: ModelId +) -> GenerationChunk | None: + match response: + case None: + return None + case GenerationResponse(): + if response.finish_reason == "error": + return ErrorChunk( + error_message=response.text, + model=model_id, + ) + else: + finish_reason = response.finish_reason + assert finish_reason not in ( + "error", + "tool_calls", + "function_call", + ) + return TokenChunk( + model=model_id, + text=response.text, + token_id=response.token, + usage=response.usage, + finish_reason=finish_reason, + stats=response.stats, + logprob=response.logprob, + top_logprobs=response.top_logprobs, + is_thinking=response.is_thinking, + ) + case ToolCallResponse(): + return ToolCallChunk( + tool_calls=response.tool_calls, + model=model_id, + usage=response.usage, + stats=response.stats, + ) def parse_gpt_oss( diff --git a/src/exo/worker/runner/llm_inference/runner.py b/src/exo/worker/runner/llm_inference/runner.py index 556501137..9cc2506a2 100644 --- a/src/exo/worker/runner/llm_inference/runner.py +++ b/src/exo/worker/runner/llm_inference/runner.py @@ -1,5 +1,6 @@ import os import time +from collections.abc import Generator from dataclasses import dataclass from enum import Enum @@ -8,11 +9,7 @@ from anyio import WouldBlock from mlx_lm.tokenizer_utils import TokenizerWrapper from exo.shared.models.model_cards import ModelTask -from exo.shared.types.chunks import ( - ErrorChunk, - TokenChunk, - ToolCallChunk, -) +from exo.shared.types.chunks import GenerationChunk from exo.shared.types.common import CommandId, ModelId from exo.shared.types.events import ( ChunkGenerated, @@ -34,8 +31,7 @@ from exo.shared.types.tasks import ( ) from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.runner_response import ( - GenerationResponse, - ToolCallResponse, + ModelLoadingResponse, ) from exo.shared.types.worker.runners import ( RunnerConnected, @@ -181,23 +177,28 @@ class Runner: ) self.acknowledge_task(task) - def on_layer_loaded(layers_loaded: int, total: int) -> None: - self.update_status( - RunnerLoading(layers_loaded=layers_loaded, total_layers=total) - ) - assert ( ModelTask.TextGeneration in self.shard_metadata.model_card.tasks ), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}" - ( - self.generator.inference_model, - self.generator.tokenizer, - self.generator.vision_processor, - ) = load_mlx_items( - self.bound_instance, - self.generator.group, - on_layer_loaded=on_layer_loaded, - ) + + def load_model() -> Generator[ModelLoadingResponse]: + assert isinstance(self.generator, Builder) + ( + self.generator.inference_model, + self.generator.tokenizer, + self.generator.vision_processor, + ) = yield from load_mlx_items( + self.bound_instance, + self.generator.group, + ) + + for load_resp in load_model(): + self.update_status( + RunnerLoading( + layers_loaded=load_resp.layers_loaded, + total_layers=load_resp.total, + ) + ) self.generator = self.generator.build() @@ -278,9 +279,7 @@ class Runner: self.send_task_status(task_id, TaskStatus.Complete) finished.append(task_id) case _: - self.send_response( - result, self.active_tasks[task_id].command_id - ) + self.send_chunk(result, self.active_tasks[task_id].command_id) for task_id in finished: self.active_tasks.pop(task_id, None) @@ -313,59 +312,13 @@ class Runner: return ExitCode.AllTasksComplete - def send_response( + def send_chunk( self, - response: GenerationResponse | ToolCallResponse, + chunk: GenerationChunk, command_id: CommandId, ): - match response: - case GenerationResponse(): - if self.device_rank == 0 and response.finish_reason == "error": - self.event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ErrorChunk( - error_message=response.text, - model=self.model_id, - ), - ) - ) - - elif self.device_rank == 0: - assert response.finish_reason not in ( - "error", - "tool_calls", - "function_call", - ) - self.event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=TokenChunk( - model=self.model_id, - text=response.text, - token_id=response.token, - usage=response.usage, - finish_reason=response.finish_reason, - stats=response.stats, - logprob=response.logprob, - top_logprobs=response.top_logprobs, - is_thinking=response.is_thinking, - ), - ) - ) - case ToolCallResponse(): - if self.device_rank == 0: - self.event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ToolCallChunk( - tool_calls=response.tool_calls, - model=self.model_id, - usage=response.usage, - stats=response.stats, - ), - ) - ) + if self.device_rank == 0: + self.event_sender.send(ChunkGenerated(command_id=command_id, chunk=chunk)) @dataclass diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index 62b0840ea..a4133b1ca 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -1,6 +1,7 @@ # Check tasks are complete before runner is ever ready. import unittest.mock from collections.abc import Iterable +from dataclasses import dataclass from typing import Callable import mlx.core as mx @@ -115,13 +116,22 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even assert test_event == true_event, f"{test_event} != {true_event}" +@dataclass +class MockLoadOutput: + layers_loaded: int + total: int + + @pytest.fixture def patch_out_mlx(monkeypatch: pytest.MonkeyPatch): # initialize_mlx returns a mock group monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup())) - monkeypatch.setattr( - mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer, None)) - ) + + def lmi_gen(): + yield MockLoadOutput(1, 1) + return (1, MockTokenizer, None) + + monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin(lmi_gen())) monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1)) monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin) monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False)) @@ -318,6 +328,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch): runner_status=RunnerLoading(layers_loaded=0, total_layers=32), ), TaskAcknowledged(task_id=LOAD_TASK_ID), + RunnerStatusUpdated( + runner_id=RUNNER_1_ID, + runner_status=RunnerLoading(layers_loaded=1, total_layers=1), + ), TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete), RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()), TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),