remove layer loading callback (#1890)

first part of modularising the backend is simplifying some of the
control flow. more tbd.
This commit is contained in:
Evan Quiney
2026-04-22 14:03:31 +01:00
committed by GitHub
parent df332035ef
commit 0a549f8846
10 changed files with 254 additions and 308 deletions

View File

@@ -70,6 +70,11 @@ class FinishedResponse(BaseRunnerResponse):
pass
class ModelLoadingResponse(BaseRunnerResponse):
layers_loaded: int
total: int
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -8,6 +8,7 @@ from PIL import Image
from exo.api.types import AdvancedImageParams
from exo.download.download_utils import build_model_path
from exo.shared.types.common import ModelId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
@@ -22,13 +23,14 @@ from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
model_id: ModelId
_config: ImageModelConfig
_adapter: ModelAdapter[Any, Any]
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
model_id: ModelId,
local_path: Path,
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
group: Optional[mx.distributed.Group] = None,
@@ -68,6 +70,7 @@ class DistributedImageModel:
else:
logger.info("Single-node initialization")
self.model_id = model_id
self._config = config
self._adapter = adapter
self._runner = runner

View File

@@ -3,7 +3,7 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Generator, Literal
@@ -17,11 +17,10 @@ from exo.api.types import (
ImageGenerationTaskParams,
ImageSize,
)
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.types.chunks import ImageChunk
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.distributed_model import DistributedImageModel
@@ -71,16 +70,8 @@ def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsTaskParams,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
"""
) -> Generator[ImageChunk, None, None]:
"""Generate image(s), optionally yielding partial results."""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
@@ -142,12 +133,14 @@ def generate_image(
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield PartialImageResponse(
yield from _process_image_response(
image_data=buffer.getvalue(),
format=task.output_format,
image_format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
model_id=model.model_id,
stats=None,
)
else:
image = result
@@ -189,9 +182,54 @@ def generate_image(
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
yield from _process_image_response(
image_data=buffer.getvalue(),
format=task.output_format,
image_format=task.output_format,
stats=stats,
image_index=image_num,
model_id=model.model_id,
partial_index=None,
total_partials=None,
)
def _process_image_response(
image_data: bytes,
image_index: int,
image_format: Literal["png", "jpeg", "webp"],
partial_index: int | None,
total_partials: int | None,
stats: ImageGenerationStats | None,
model_id: ModelId,
) -> Iterator[ImageChunk]:
"""Process a single image response and send chunks."""
is_partial = partial_index is not None
encoded_data = base64.b64encode(image_data).decode("utf-8")
# Extract stats from final ImageGenerationResponse if available
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
def _data_to_chunk(item: tuple[int, str]) -> ImageChunk:
chunk_index, chunk_data = item
# Only include stats on the last chunk of the final image
chunk_stats = (
stats if chunk_index == total_chunks - 1 and not is_partial else None
)
return ImageChunk(
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
stats=chunk_stats,
format=image_format,
)
return map(_data_to_chunk, enumerate(data_chunks))

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Generator
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Literal, Protocol, cast
@@ -59,14 +59,13 @@ from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from exo.shared.types.worker.runner_response import ModelLoadingResponse
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.bootstrap import logger
if TYPE_CHECKING:
from mlx_lm.models.cache import Cache
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
@@ -276,8 +275,7 @@ def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
model_shard_meta: PipelineShardMetadata,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
"""
Automatically parallelize a model across multiple devices.
Args:
@@ -297,8 +295,7 @@ def pipeline_auto_parallel(
total = len(layers)
for i, layer in enumerate(layers):
mx.eval(layer) # type: ignore
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
@@ -460,8 +457,7 @@ def patch_tensor_model[T](model: T) -> T:
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
all_to_sharded_linear = partial(
shard_linear,
sharding="all-to-sharded",
@@ -595,7 +591,7 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(model, on_layer_loaded)
model = yield from tensor_parallel_sharding_strategy.shard_model(model)
return patch_tensor_model(model)
@@ -619,16 +615,14 @@ class TensorParallelShardingStrategy(ABC):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module: ...
) -> Generator[ModelLoadingResponse, None, nn.Module]: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(LlamaModel, model)
total = len(model.layers)
for i, layer in enumerate(model.layers):
@@ -646,8 +640,8 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -681,8 +675,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(DeepseekV3Model, model)
total = len(model.layers)
@@ -738,8 +731,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -764,8 +757,7 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(GLM4MoeLiteModel, model)
total = len(model.layers) # type: ignore
for i, layer in enumerate(model.layers): # type: ignore
@@ -816,8 +808,8 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # type: ignore
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -904,8 +896,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(MiniMaxModel, model)
total = len(model.layers)
for i, layer in enumerate(model.layers):
@@ -934,8 +925,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -943,8 +934,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(
Qwen3Model
| Qwen3MoeModel
@@ -1099,8 +1089,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -1108,8 +1098,7 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(Glm4MoeModel, model)
total = len(model.layers)
for i, layer in enumerate(model.layers):
@@ -1145,8 +1134,8 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -1154,8 +1143,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(GptOssMoeModel, model)
total = len(model.layers)
@@ -1186,8 +1174,8 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -1195,8 +1183,7 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(Step35Model, model)
total = len(model.layers)
@@ -1229,8 +1216,8 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
@@ -1238,8 +1225,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(NemotronHModel, model)
rank = self.group.rank()
total = len(model.layers)
@@ -1272,8 +1258,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy):
layer.mixer = mixer # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model
def _shard_mamba2_mixer(self, mixer: NemotronHMamba2Mixer, rank: int) -> None:
@@ -1380,8 +1365,7 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> Generator[ModelLoadingResponse, None, nn.Module]:
model = cast(Gemma4Model, model)
layers = model.language_model.model.layers
total = len(layers)
@@ -1409,6 +1393,5 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy):
layer.experts.sharding_group = self.group
mx.eval(layer)
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
return model

View File

@@ -4,6 +4,7 @@ import re
import sys
import tempfile
import time
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
@@ -51,6 +52,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runner_response import ModelLoadingResponse
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
@@ -58,7 +60,6 @@ from exo.shared.types.worker.shards import (
TensorShardMetadata,
)
from exo.worker.engines.mlx.auto_parallel import (
LayerLoadedCallback,
get_inner_model,
get_layers,
pipeline_auto_parallel,
@@ -66,8 +67,6 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
@@ -90,7 +89,7 @@ class HostList(RootModel[list[str]]):
def mlx_distributed_init(
bound_instance: BoundInstance,
) -> Group:
) -> mx.distributed.Group:
"""
Initialize MLX distributed.
"""
@@ -149,7 +148,7 @@ def mlx_distributed_init(
def initialize_mlx(
bound_instance: BoundInstance,
) -> Group:
) -> mx.distributed.Group:
# should we unseed it?
# TODO: pass in seed from params
mx.random.seed(42)
@@ -162,9 +161,10 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> "tuple[Model, TokenizerWrapper, VisionProcessor | None]":
group: mx.distributed.Group | None,
) -> Generator[
ModelLoadingResponse, None, tuple[Model, TokenizerWrapper, "VisionProcessor | None"]
]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
@@ -177,8 +177,7 @@ def load_mlx_items(
total = len(layers)
for i, layer in enumerate(layers):
mx.eval(layer) # type: ignore
if on_layer_loaded is not None:
on_layer_loaded(i, total)
yield ModelLoadingResponse(layers_loaded=i, total=total)
except ValueError as e:
logger.opt(exception=e).debug(
"Model architecture doesn't support layer-by-layer progress tracking",
@@ -191,10 +190,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
model, tokenizer = yield from shard_and_load(
bound_instance.bound_shard,
group=group,
on_layer_loaded=on_layer_loaded,
)
end_time = time.perf_counter()
logger.info(
@@ -221,9 +219,8 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_layer_loaded: LayerLoadedCallback | None,
) -> tuple[nn.Module, TokenizerWrapper]:
group: mx.distributed.Group,
) -> Generator[ModelLoadingResponse, None, tuple[nn.Module, TokenizerWrapper]]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
@@ -254,12 +251,10 @@ def shard_and_load(
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, on_layer_loaded)
model = yield from tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(
model, group, shard_metadata, on_layer_loaded=on_layer_loaded
)
model = yield from pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
case CfgShardMetadata():
raise ValueError(
@@ -748,7 +743,9 @@ def set_wired_limit_for_model(model_size: Memory):
def mlx_cleanup(
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
model: Model | None,
tokenizer: TokenizerWrapper | None,
group: mx.distributed.Group | None,
) -> None:
del model, tokenizer, group
mx.clear_cache()
@@ -757,7 +754,7 @@ def mlx_cleanup(
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
def mx_any(bool_: bool, group: mx.distributed.Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
@@ -767,7 +764,7 @@ def mx_any(bool_: bool, group: Group | None) -> bool:
return num_true.item() > 0
def mx_barrier(group: Group | None):
def mx_barrier(group: mx.distributed.Group | None):
if group is None:
return
mx.eval(

View File

@@ -1,19 +1,17 @@
import base64
import time
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING
import mlx.core as mx
from exo.api.types import (
ImageEditsTaskParams,
ImageGenerationStats,
ImageGenerationTaskParams,
)
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.chunks import ErrorChunk, ImageChunk
from exo.shared.types.common import CommandId, ModelId
from exo.shared.types.chunks import ErrorChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -36,10 +34,6 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -87,32 +81,6 @@ def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
return False
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
# Extract stats from final ImageGenerationResponse if available
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
stats=stats,
image_format=response.format,
)
def _send_traces_if_enabled(
event_sender: MpSender[Event],
task_id: TaskId,
@@ -143,48 +111,6 @@ def _send_traces_if_enabled(
clear_trace_buffer()
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
stats: ImageGenerationStats | None = None,
image_format: Literal["png", "jpeg", "webp"] | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
# Only include stats on the last chunk of the final image
chunk_stats = (
stats if chunk_index == total_chunks - 1 and not is_partial else None
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
stats=chunk_stats,
format=image_format,
),
)
)
class Runner:
def __init__(
self,
@@ -261,35 +187,21 @@ class Runner:
return self._check_cancelled(task.task_id)
try:
image_index = 0
for response in generate_image(
for chunk in generate_image(
model=self.image_model,
task=task_params,
cancel_checker=cancel_checker,
):
if _is_primary_output_node(self.shard_metadata):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
image_index += 1
if chunk.is_partial:
logger.info(
f"sending partial ImageChunk {chunk.partial_index}/{chunk.total_partials}"
)
else:
logger.info("sending final ImageChunk")
self.event_sender.send(
ChunkGenerated(command_id=command_id, chunk=chunk)
)
except Exception as e:
if _is_primary_output_node(self.shard_metadata):
self.event_sender.send(

View File

@@ -2,20 +2,20 @@ import itertools
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Generator, Iterable
from collections.abc import Generator, Iterator
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.constants import EXO_MAX_CONCURRENT_REQUESTS
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.chunks import ErrorChunk, GenerationChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import CANCEL_ALL_TASKS, TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator
@@ -32,7 +32,7 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.engines.mlx.vision import VisionProcessor
from exo.worker.runner.bootstrap import logger
from .model_output_parsers import apply_all_parsers
from .model_output_parsers import apply_all_parsers, map_responses_to_chunks
from .tool_parsers import ToolParser
@@ -80,9 +80,7 @@ class InferenceGenerator(ABC):
@abstractmethod
def step(
self,
) -> Iterable[
tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
]: ...
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: ...
@abstractmethod
def close(self) -> None: ...
@@ -137,7 +135,7 @@ class SequentialGenerator(InferenceGenerator):
# queue that the 1st generator should push to and 3rd generator should pull from
GeneratorQueue[GenerationResponse],
# generator to get parsed outputs
Generator[GenerationResponse | ToolCallResponse | None],
Iterator[GenerationChunk | None],
]
| None
) = field(default=None, init=False)
@@ -183,9 +181,7 @@ class SequentialGenerator(InferenceGenerator):
def step(
self,
) -> Iterable[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
]:
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]:
if self._active is None:
self.agree_on_tasks()
@@ -197,9 +193,7 @@ class SequentialGenerator(InferenceGenerator):
assert self._active is not None
task, mlx_gen, queue, output_generator = self._active
output: list[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
] = []
output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = []
try:
response = next(mlx_gen)
queue.push(response)
@@ -233,7 +227,9 @@ class SequentialGenerator(InferenceGenerator):
queue = GeneratorQueue[GenerationResponse]()
if task.task_params.bench:
output_generator = queue.gen()
output_generator: Iterator[GenerationChunk | None] = map(
lambda r: map_responses_to_chunks(r, self.model_id), queue.gen()
)
else:
output_generator = apply_all_parsers(
queue.gen(),
@@ -338,7 +334,7 @@ class BatchGenerator(InferenceGenerator):
tuple[
TextGeneration,
GeneratorQueue[GenerationResponse],
Generator[GenerationResponse | ToolCallResponse | None],
Iterator[GenerationChunk | None],
],
] = field(default_factory=dict, init=False)
@@ -392,9 +388,7 @@ class BatchGenerator(InferenceGenerator):
def step(
self,
) -> Iterable[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
]:
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]:
if not self._queue:
self.agree_on_tasks()
@@ -411,7 +405,9 @@ class BatchGenerator(InferenceGenerator):
queue = GeneratorQueue[GenerationResponse]()
if task.task_params.bench:
output_generator = queue.gen()
output_generator: Iterator[GenerationChunk | None] = map(
lambda r: map_responses_to_chunks(r, self.model_id), queue.gen()
)
else:
output_generator = apply_all_parsers(
queue.gen(),
@@ -429,9 +425,7 @@ class BatchGenerator(InferenceGenerator):
results = self._mlx_gen.step()
output: list[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
] = []
output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = []
for uid, response in results:
if uid not in self._active_tasks:
# should we error here?
@@ -453,9 +447,9 @@ class BatchGenerator(InferenceGenerator):
def _apply_cancellations(
self,
) -> list[tuple[TaskId, Cancelled]]:
) -> Iterator[tuple[TaskId, Cancelled]]:
if not self._cancelled_tasks:
return []
return iter([])
cancel_all = CANCEL_ALL_TASKS in self._cancelled_tasks
@@ -477,7 +471,7 @@ class BatchGenerator(InferenceGenerator):
results.append((tid, Cancelled()))
self._cancelled_tasks.clear()
return results
return iter(results)
def _send_error(self, task: TextGeneration, e: Exception) -> None:
if self.device_rank == 0:

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from functools import cache
from typing import Any
@@ -14,6 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
)
from exo.api.types import ToolCallItem
from exo.shared.types.chunks import (
ErrorChunk,
GenerationChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
@@ -64,29 +70,70 @@ def apply_all_parsers(
model_type: type[Model],
model_id: ModelId,
tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
mlx_generator = receiver
) -> Iterator[GenerationChunk | None]:
generator = receiver
if issubclass(model_type, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
generator = parse_gpt_oss(generator)
elif (
issubclass(model_type, DeepseekV32Model)
and "deepseek" in model_id.normalize().lower()
):
mlx_generator = parse_deepseek_v32(mlx_generator)
generator = parse_deepseek_v32(generator)
else:
if tokenizer.has_thinking:
mlx_generator = parse_thinking_models(
mlx_generator,
generator = parse_thinking_models(
generator,
tokenizer.think_start,
tokenizer.think_end,
starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
)
if tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools)
generator = parse_tool_calls(generator, tool_parser, tools)
return count_reasoning_tokens(mlx_generator)
generator = count_reasoning_tokens(generator)
return map(lambda r: map_responses_to_chunks(r, model_id), generator)
def map_responses_to_chunks(
response: GenerationResponse | ToolCallResponse | None, model_id: ModelId
) -> GenerationChunk | None:
match response:
case None:
return None
case GenerationResponse():
if response.finish_reason == "error":
return ErrorChunk(
error_message=response.text,
model=model_id,
)
else:
finish_reason = response.finish_reason
assert finish_reason not in (
"error",
"tool_calls",
"function_call",
)
return TokenChunk(
model=model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=finish_reason,
stats=response.stats,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
is_thinking=response.is_thinking,
)
case ToolCallResponse():
return ToolCallChunk(
tool_calls=response.tool_calls,
model=model_id,
usage=response.usage,
stats=response.stats,
)
def parse_gpt_oss(

View File

@@ -1,5 +1,6 @@
import os
import time
from collections.abc import Generator
from dataclasses import dataclass
from enum import Enum
@@ -8,11 +9,7 @@ from anyio import WouldBlock
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.chunks import (
ErrorChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, ModelId
from exo.shared.types.events import (
ChunkGenerated,
@@ -34,8 +31,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
ModelLoadingResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -181,23 +177,28 @@ class Runner:
)
self.acknowledge_task(task)
def on_layer_loaded(layers_loaded: int, total: int) -> None:
self.update_status(
RunnerLoading(layers_loaded=layers_loaded, total_layers=total)
)
assert (
ModelTask.TextGeneration in self.shard_metadata.model_card.tasks
), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
(
self.generator.inference_model,
self.generator.tokenizer,
self.generator.vision_processor,
) = load_mlx_items(
self.bound_instance,
self.generator.group,
on_layer_loaded=on_layer_loaded,
)
def load_model() -> Generator[ModelLoadingResponse]:
assert isinstance(self.generator, Builder)
(
self.generator.inference_model,
self.generator.tokenizer,
self.generator.vision_processor,
) = yield from load_mlx_items(
self.bound_instance,
self.generator.group,
)
for load_resp in load_model():
self.update_status(
RunnerLoading(
layers_loaded=load_resp.layers_loaded,
total_layers=load_resp.total,
)
)
self.generator = self.generator.build()
@@ -278,9 +279,7 @@ class Runner:
self.send_task_status(task_id, TaskStatus.Complete)
finished.append(task_id)
case _:
self.send_response(
result, self.active_tasks[task_id].command_id
)
self.send_chunk(result, self.active_tasks[task_id].command_id)
for task_id in finished:
self.active_tasks.pop(task_id, None)
@@ -313,59 +312,13 @@ class Runner:
return ExitCode.AllTasksComplete
def send_response(
def send_chunk(
self,
response: GenerationResponse | ToolCallResponse,
chunk: GenerationChunk,
command_id: CommandId,
):
match response:
case GenerationResponse():
if self.device_rank == 0 and response.finish_reason == "error":
self.event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
error_message=response.text,
model=self.model_id,
),
)
)
elif self.device_rank == 0:
assert response.finish_reason not in (
"error",
"tool_calls",
"function_call",
)
self.event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
model=self.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
is_thinking=response.is_thinking,
),
)
)
case ToolCallResponse():
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=self.model_id,
usage=response.usage,
stats=response.stats,
),
)
)
if self.device_rank == 0:
self.event_sender.send(ChunkGenerated(command_id=command_id, chunk=chunk))
@dataclass

View File

@@ -1,6 +1,7 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Callable
import mlx.core as mx
@@ -115,13 +116,22 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
assert test_event == true_event, f"{test_event} != {true_event}"
@dataclass
class MockLoadOutput:
layers_loaded: int
total: int
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer, None))
)
def lmi_gen():
yield MockLoadOutput(1, 1)
return (1, MockTokenizer, None)
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin(lmi_gen()))
monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
@@ -318,6 +328,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
runner_status=RunnerLoading(layers_loaded=0, total_layers=32),
),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID,
runner_status=RunnerLoading(layers_loaded=1, total_layers=1),
),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),