mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 04:05:23 -04:00
remove layer loading callback (#1890)
first part of modularising the backend is simplifying some of the control flow. more tbd.
This commit is contained in:
@@ -70,6 +70,11 @@ class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoadingResponse(BaseRunnerResponse):
|
||||
layers_loaded: int
|
||||
total: int
|
||||
|
||||
|
||||
class PrefillProgressResponse(BaseRunnerResponse):
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
@@ -8,6 +8,7 @@ from PIL import Image
|
||||
|
||||
from exo.api.types import AdvancedImageParams
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
@@ -22,13 +23,14 @@ from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
class DistributedImageModel:
|
||||
model_id: ModelId
|
||||
_config: ImageModelConfig
|
||||
_adapter: ModelAdapter[Any, Any]
|
||||
_runner: DiffusionRunner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model_id: ModelId,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
@@ -68,6 +70,7 @@ class DistributedImageModel:
|
||||
else:
|
||||
logger.info("Single-node initialization")
|
||||
|
||||
self.model_id = model_id
|
||||
self._config = config
|
||||
self._adapter = adapter
|
||||
self._runner = runner
|
||||
|
||||
@@ -3,7 +3,7 @@ import io
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
@@ -17,11 +17,10 @@ from exo.api.types import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageSize,
|
||||
)
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.types.chunks import ImageChunk
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
@@ -71,16 +70,8 @@ def generate_image(
|
||||
model: DistributedImageModel,
|
||||
task: ImageGenerationTaskParams | ImageEditsTaskParams,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
"""
|
||||
) -> Generator[ImageChunk, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results."""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
@@ -142,12 +133,14 @@ def generate_image(
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield PartialImageResponse(
|
||||
yield from _process_image_response(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
image_format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
image_index=image_num,
|
||||
model_id=model.model_id,
|
||||
stats=None,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
@@ -189,9 +182,54 @@ def generate_image(
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
yield from _process_image_response(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
image_format=task.output_format,
|
||||
stats=stats,
|
||||
image_index=image_num,
|
||||
model_id=model.model_id,
|
||||
partial_index=None,
|
||||
total_partials=None,
|
||||
)
|
||||
|
||||
|
||||
def _process_image_response(
|
||||
image_data: bytes,
|
||||
image_index: int,
|
||||
image_format: Literal["png", "jpeg", "webp"],
|
||||
partial_index: int | None,
|
||||
total_partials: int | None,
|
||||
stats: ImageGenerationStats | None,
|
||||
model_id: ModelId,
|
||||
) -> Iterator[ImageChunk]:
|
||||
"""Process a single image response and send chunks."""
|
||||
is_partial = partial_index is not None
|
||||
encoded_data = base64.b64encode(image_data).decode("utf-8")
|
||||
# Extract stats from final ImageGenerationResponse if available
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
|
||||
def _data_to_chunk(item: tuple[int, str]) -> ImageChunk:
|
||||
chunk_index, chunk_data = item
|
||||
# Only include stats on the last chunk of the final image
|
||||
chunk_stats = (
|
||||
stats if chunk_index == total_chunks - 1 and not is_partial else None
|
||||
)
|
||||
|
||||
return ImageChunk(
|
||||
model=model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=partial_index,
|
||||
total_partials=total_partials,
|
||||
stats=chunk_stats,
|
||||
format=image_format,
|
||||
)
|
||||
|
||||
return map(_data_to_chunk, enumerate(data_chunks))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Generator
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Literal, Protocol, cast
|
||||
@@ -59,14 +59,13 @@ from mlx_lm.models.step3p5 import Model as Step35Model
|
||||
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
|
||||
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
|
||||
|
||||
from exo.shared.types.worker.runner_response import ModelLoadingResponse
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
|
||||
|
||||
|
||||
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
|
||||
|
||||
@@ -276,8 +275,7 @@ def pipeline_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
model_shard_meta: PipelineShardMetadata,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
"""
|
||||
Automatically parallelize a model across multiple devices.
|
||||
Args:
|
||||
@@ -297,8 +295,7 @@ def pipeline_auto_parallel(
|
||||
total = len(layers)
|
||||
for i, layer in enumerate(layers):
|
||||
mx.eval(layer) # type: ignore
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
@@ -460,8 +457,7 @@ def patch_tensor_model[T](model: T) -> T:
|
||||
def tensor_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
all_to_sharded_linear = partial(
|
||||
shard_linear,
|
||||
sharding="all-to-sharded",
|
||||
@@ -595,7 +591,7 @@ def tensor_auto_parallel(
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
model = tensor_parallel_sharding_strategy.shard_model(model, on_layer_loaded)
|
||||
model = yield from tensor_parallel_sharding_strategy.shard_model(model)
|
||||
return patch_tensor_model(model)
|
||||
|
||||
|
||||
@@ -619,16 +615,14 @@ class TensorParallelShardingStrategy(ABC):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module: ...
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]: ...
|
||||
|
||||
|
||||
class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(LlamaModel, model)
|
||||
total = len(model.layers)
|
||||
for i, layer in enumerate(model.layers):
|
||||
@@ -646,8 +640,8 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -681,8 +675,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
total = len(model.layers)
|
||||
|
||||
@@ -738,8 +731,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
|
||||
return model
|
||||
|
||||
@@ -764,8 +757,7 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
total = len(model.layers) # type: ignore
|
||||
for i, layer in enumerate(model.layers): # type: ignore
|
||||
@@ -816,8 +808,8 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
|
||||
return model
|
||||
|
||||
@@ -904,8 +896,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(MiniMaxModel, model)
|
||||
total = len(model.layers)
|
||||
for i, layer in enumerate(model.layers):
|
||||
@@ -934,8 +925,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -943,8 +934,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(
|
||||
Qwen3Model
|
||||
| Qwen3MoeModel
|
||||
@@ -1099,8 +1089,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -1108,8 +1098,7 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(Glm4MoeModel, model)
|
||||
total = len(model.layers)
|
||||
for i, layer in enumerate(model.layers):
|
||||
@@ -1145,8 +1134,8 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -1154,8 +1143,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
total = len(model.layers)
|
||||
|
||||
@@ -1186,8 +1174,8 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -1195,8 +1183,7 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(Step35Model, model)
|
||||
total = len(model.layers)
|
||||
|
||||
@@ -1229,8 +1216,8 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
|
||||
@@ -1238,8 +1225,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(NemotronHModel, model)
|
||||
rank = self.group.rank()
|
||||
total = len(model.layers)
|
||||
@@ -1272,8 +1258,7 @@ class NemotronHShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mixer = mixer # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
def _shard_mamba2_mixer(self, mixer: NemotronHMamba2Mixer, rank: int) -> None:
|
||||
@@ -1380,8 +1365,7 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> nn.Module:
|
||||
) -> Generator[ModelLoadingResponse, None, nn.Module]:
|
||||
model = cast(Gemma4Model, model)
|
||||
layers = model.language_model.model.layers
|
||||
total = len(layers)
|
||||
@@ -1409,6 +1393,5 @@ class Gemma4ShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.experts.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
return model
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
@@ -51,6 +52,7 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runner_response import ModelLoadingResponse
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
@@ -58,7 +60,6 @@ from exo.shared.types.worker.shards import (
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
LayerLoadedCallback,
|
||||
get_inner_model,
|
||||
get_layers,
|
||||
pipeline_auto_parallel,
|
||||
@@ -66,8 +67,6 @@ from exo.worker.engines.mlx.auto_parallel import (
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
|
||||
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
@@ -90,7 +89,7 @@ class HostList(RootModel[list[str]]):
|
||||
|
||||
def mlx_distributed_init(
|
||||
bound_instance: BoundInstance,
|
||||
) -> Group:
|
||||
) -> mx.distributed.Group:
|
||||
"""
|
||||
Initialize MLX distributed.
|
||||
"""
|
||||
@@ -149,7 +148,7 @@ def mlx_distributed_init(
|
||||
|
||||
def initialize_mlx(
|
||||
bound_instance: BoundInstance,
|
||||
) -> Group:
|
||||
) -> mx.distributed.Group:
|
||||
# should we unseed it?
|
||||
# TODO: pass in seed from params
|
||||
mx.random.seed(42)
|
||||
@@ -162,9 +161,10 @@ def initialize_mlx(
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> "tuple[Model, TokenizerWrapper, VisionProcessor | None]":
|
||||
group: mx.distributed.Group | None,
|
||||
) -> Generator[
|
||||
ModelLoadingResponse, None, tuple[Model, TokenizerWrapper, "VisionProcessor | None"]
|
||||
]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
@@ -177,8 +177,7 @@ def load_mlx_items(
|
||||
total = len(layers)
|
||||
for i, layer in enumerate(layers):
|
||||
mx.eval(layer) # type: ignore
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
yield ModelLoadingResponse(layers_loaded=i, total=total)
|
||||
except ValueError as e:
|
||||
logger.opt(exception=e).debug(
|
||||
"Model architecture doesn't support layer-by-layer progress tracking",
|
||||
@@ -191,10 +190,9 @@ def load_mlx_items(
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
model, tokenizer = yield from shard_and_load(
|
||||
bound_instance.bound_shard,
|
||||
group=group,
|
||||
on_layer_loaded=on_layer_loaded,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
@@ -221,9 +219,8 @@ def load_mlx_items(
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
group: mx.distributed.Group,
|
||||
) -> Generator[ModelLoadingResponse, None, tuple[nn.Module, TokenizerWrapper]]:
|
||||
model_path = build_model_path(shard_metadata.model_card.model_id)
|
||||
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
@@ -254,12 +251,10 @@ def shard_and_load(
|
||||
match shard_metadata:
|
||||
case TensorShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with tensor parallelism")
|
||||
model = tensor_auto_parallel(model, group, on_layer_loaded)
|
||||
model = yield from tensor_auto_parallel(model, group)
|
||||
case PipelineShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(
|
||||
model, group, shard_metadata, on_layer_loaded=on_layer_loaded
|
||||
)
|
||||
model = yield from pipeline_auto_parallel(model, group, shard_metadata)
|
||||
mx.eval(model.parameters())
|
||||
case CfgShardMetadata():
|
||||
raise ValueError(
|
||||
@@ -748,7 +743,9 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
|
||||
|
||||
def mlx_cleanup(
|
||||
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
|
||||
model: Model | None,
|
||||
tokenizer: TokenizerWrapper | None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> None:
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
@@ -757,7 +754,7 @@ def mlx_cleanup(
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
def mx_any(bool_: bool, group: mx.distributed.Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
@@ -767,7 +764,7 @@ def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
def mx_barrier(group: mx.distributed.Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import base64
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.api.types import (
|
||||
ImageEditsTaskParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
from exo.shared.types.chunks import ErrorChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -36,10 +34,6 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
@@ -87,32 +81,6 @@ def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _process_image_response(
|
||||
response: ImageGenerationResponse | PartialImageResponse,
|
||||
command_id: CommandId,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
image_index: int,
|
||||
) -> None:
|
||||
"""Process a single image response and send chunks."""
|
||||
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
|
||||
is_partial = isinstance(response, PartialImageResponse)
|
||||
# Extract stats from final ImageGenerationResponse if available
|
||||
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
|
||||
_send_image_chunk(
|
||||
encoded_data=encoded_data,
|
||||
command_id=command_id,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
event_sender=event_sender,
|
||||
image_index=response.image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=response.partial_index if is_partial else None,
|
||||
total_partials=response.total_partials if is_partial else None,
|
||||
stats=stats,
|
||||
image_format=response.format,
|
||||
)
|
||||
|
||||
|
||||
def _send_traces_if_enabled(
|
||||
event_sender: MpSender[Event],
|
||||
task_id: TaskId,
|
||||
@@ -143,48 +111,6 @@ def _send_traces_if_enabled(
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
encoded_data: str,
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
event_sender: MpSender[Event],
|
||||
image_index: int,
|
||||
is_partial: bool,
|
||||
partial_index: int | None = None,
|
||||
total_partials: int | None = None,
|
||||
stats: ImageGenerationStats | None = None,
|
||||
image_format: Literal["png", "jpeg", "webp"] | None = None,
|
||||
) -> None:
|
||||
"""Send base64-encoded image data as chunks via events."""
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
for chunk_index, chunk_data in enumerate(data_chunks):
|
||||
# Only include stats on the last chunk of the final image
|
||||
chunk_stats = (
|
||||
stats if chunk_index == total_chunks - 1 and not is_partial else None
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
model=model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=partial_index,
|
||||
total_partials=total_partials,
|
||||
stats=chunk_stats,
|
||||
format=image_format,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -261,35 +187,21 @@ class Runner:
|
||||
return self._check_cancelled(task.task_id)
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
for chunk in generate_image(
|
||||
model=self.image_model,
|
||||
task=task_params,
|
||||
cancel_checker=cancel_checker,
|
||||
):
|
||||
if _is_primary_output_node(self.shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
self.shard_metadata,
|
||||
self.event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
self.shard_metadata,
|
||||
self.event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
if chunk.is_partial:
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {chunk.partial_index}/{chunk.total_partials}"
|
||||
)
|
||||
else:
|
||||
logger.info("sending final ImageChunk")
|
||||
self.event_sender.send(
|
||||
ChunkGenerated(command_id=command_id, chunk=chunk)
|
||||
)
|
||||
except Exception as e:
|
||||
if _is_primary_output_node(self.shard_metadata):
|
||||
self.event_sender.send(
|
||||
|
||||
@@ -2,20 +2,20 @@ import itertools
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Generator, Iterable
|
||||
from collections.abc import Generator, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CONCURRENT_REQUESTS
|
||||
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
|
||||
from exo.shared.types.chunks import ErrorChunk, GenerationChunk, PrefillProgressChunk
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.tasks import CANCEL_ALL_TASKS, TaskId, TextGeneration
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator
|
||||
@@ -32,7 +32,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
from exo.worker.engines.mlx.vision import VisionProcessor
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
from .model_output_parsers import apply_all_parsers
|
||||
from .model_output_parsers import apply_all_parsers, map_responses_to_chunks
|
||||
from .tool_parsers import ToolParser
|
||||
|
||||
|
||||
@@ -80,9 +80,7 @@ class InferenceGenerator(ABC):
|
||||
@abstractmethod
|
||||
def step(
|
||||
self,
|
||||
) -> Iterable[
|
||||
tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
|
||||
]: ...
|
||||
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: ...
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None: ...
|
||||
@@ -137,7 +135,7 @@ class SequentialGenerator(InferenceGenerator):
|
||||
# queue that the 1st generator should push to and 3rd generator should pull from
|
||||
GeneratorQueue[GenerationResponse],
|
||||
# generator to get parsed outputs
|
||||
Generator[GenerationResponse | ToolCallResponse | None],
|
||||
Iterator[GenerationChunk | None],
|
||||
]
|
||||
| None
|
||||
) = field(default=None, init=False)
|
||||
@@ -183,9 +181,7 @@ class SequentialGenerator(InferenceGenerator):
|
||||
|
||||
def step(
|
||||
self,
|
||||
) -> Iterable[
|
||||
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
|
||||
]:
|
||||
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]:
|
||||
if self._active is None:
|
||||
self.agree_on_tasks()
|
||||
|
||||
@@ -197,9 +193,7 @@ class SequentialGenerator(InferenceGenerator):
|
||||
assert self._active is not None
|
||||
|
||||
task, mlx_gen, queue, output_generator = self._active
|
||||
output: list[
|
||||
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
|
||||
] = []
|
||||
output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = []
|
||||
try:
|
||||
response = next(mlx_gen)
|
||||
queue.push(response)
|
||||
@@ -233,7 +227,9 @@ class SequentialGenerator(InferenceGenerator):
|
||||
queue = GeneratorQueue[GenerationResponse]()
|
||||
|
||||
if task.task_params.bench:
|
||||
output_generator = queue.gen()
|
||||
output_generator: Iterator[GenerationChunk | None] = map(
|
||||
lambda r: map_responses_to_chunks(r, self.model_id), queue.gen()
|
||||
)
|
||||
else:
|
||||
output_generator = apply_all_parsers(
|
||||
queue.gen(),
|
||||
@@ -338,7 +334,7 @@ class BatchGenerator(InferenceGenerator):
|
||||
tuple[
|
||||
TextGeneration,
|
||||
GeneratorQueue[GenerationResponse],
|
||||
Generator[GenerationResponse | ToolCallResponse | None],
|
||||
Iterator[GenerationChunk | None],
|
||||
],
|
||||
] = field(default_factory=dict, init=False)
|
||||
|
||||
@@ -392,9 +388,7 @@ class BatchGenerator(InferenceGenerator):
|
||||
|
||||
def step(
|
||||
self,
|
||||
) -> Iterable[
|
||||
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
|
||||
]:
|
||||
) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]:
|
||||
if not self._queue:
|
||||
self.agree_on_tasks()
|
||||
|
||||
@@ -411,7 +405,9 @@ class BatchGenerator(InferenceGenerator):
|
||||
|
||||
queue = GeneratorQueue[GenerationResponse]()
|
||||
if task.task_params.bench:
|
||||
output_generator = queue.gen()
|
||||
output_generator: Iterator[GenerationChunk | None] = map(
|
||||
lambda r: map_responses_to_chunks(r, self.model_id), queue.gen()
|
||||
)
|
||||
else:
|
||||
output_generator = apply_all_parsers(
|
||||
queue.gen(),
|
||||
@@ -429,9 +425,7 @@ class BatchGenerator(InferenceGenerator):
|
||||
|
||||
results = self._mlx_gen.step()
|
||||
|
||||
output: list[
|
||||
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
|
||||
] = []
|
||||
output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = []
|
||||
for uid, response in results:
|
||||
if uid not in self._active_tasks:
|
||||
# should we error here?
|
||||
@@ -453,9 +447,9 @@ class BatchGenerator(InferenceGenerator):
|
||||
|
||||
def _apply_cancellations(
|
||||
self,
|
||||
) -> list[tuple[TaskId, Cancelled]]:
|
||||
) -> Iterator[tuple[TaskId, Cancelled]]:
|
||||
if not self._cancelled_tasks:
|
||||
return []
|
||||
return iter([])
|
||||
|
||||
cancel_all = CANCEL_ALL_TASKS in self._cancelled_tasks
|
||||
|
||||
@@ -477,7 +471,7 @@ class BatchGenerator(InferenceGenerator):
|
||||
results.append((tid, Cancelled()))
|
||||
|
||||
self._cancelled_tasks.clear()
|
||||
return results
|
||||
return iter(results)
|
||||
|
||||
def _send_error(self, task: TextGeneration, e: Exception) -> None:
|
||||
if self.device_rank == 0:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
@@ -14,6 +14,12 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
)
|
||||
|
||||
from exo.api.types import ToolCallItem
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
GenerationChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.mlx import Model
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
@@ -64,29 +70,70 @@ def apply_all_parsers(
|
||||
model_type: type[Model],
|
||||
model_id: ModelId,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse | None]:
|
||||
mlx_generator = receiver
|
||||
) -> Iterator[GenerationChunk | None]:
|
||||
generator = receiver
|
||||
|
||||
if issubclass(model_type, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
generator = parse_gpt_oss(generator)
|
||||
elif (
|
||||
issubclass(model_type, DeepseekV32Model)
|
||||
and "deepseek" in model_id.normalize().lower()
|
||||
):
|
||||
mlx_generator = parse_deepseek_v32(mlx_generator)
|
||||
generator = parse_deepseek_v32(generator)
|
||||
else:
|
||||
if tokenizer.has_thinking:
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator,
|
||||
generator = parse_thinking_models(
|
||||
generator,
|
||||
tokenizer.think_start,
|
||||
tokenizer.think_end,
|
||||
starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
|
||||
)
|
||||
|
||||
if tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools)
|
||||
generator = parse_tool_calls(generator, tool_parser, tools)
|
||||
|
||||
return count_reasoning_tokens(mlx_generator)
|
||||
generator = count_reasoning_tokens(generator)
|
||||
|
||||
return map(lambda r: map_responses_to_chunks(r, model_id), generator)
|
||||
|
||||
|
||||
def map_responses_to_chunks(
|
||||
response: GenerationResponse | ToolCallResponse | None, model_id: ModelId
|
||||
) -> GenerationChunk | None:
|
||||
match response:
|
||||
case None:
|
||||
return None
|
||||
case GenerationResponse():
|
||||
if response.finish_reason == "error":
|
||||
return ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=model_id,
|
||||
)
|
||||
else:
|
||||
finish_reason = response.finish_reason
|
||||
assert finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
return TokenChunk(
|
||||
model=model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
usage=response.usage,
|
||||
finish_reason=finish_reason,
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
is_thinking=response.is_thinking,
|
||||
)
|
||||
case ToolCallResponse():
|
||||
return ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=model_id,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
@@ -8,11 +9,7 @@ from anyio import WouldBlock
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
@@ -34,8 +31,7 @@ from exo.shared.types.tasks import (
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
ModelLoadingResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -181,23 +177,28 @@ class Runner:
|
||||
)
|
||||
self.acknowledge_task(task)
|
||||
|
||||
def on_layer_loaded(layers_loaded: int, total: int) -> None:
|
||||
self.update_status(
|
||||
RunnerLoading(layers_loaded=layers_loaded, total_layers=total)
|
||||
)
|
||||
|
||||
assert (
|
||||
ModelTask.TextGeneration in self.shard_metadata.model_card.tasks
|
||||
), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
|
||||
(
|
||||
self.generator.inference_model,
|
||||
self.generator.tokenizer,
|
||||
self.generator.vision_processor,
|
||||
) = load_mlx_items(
|
||||
self.bound_instance,
|
||||
self.generator.group,
|
||||
on_layer_loaded=on_layer_loaded,
|
||||
)
|
||||
|
||||
def load_model() -> Generator[ModelLoadingResponse]:
|
||||
assert isinstance(self.generator, Builder)
|
||||
(
|
||||
self.generator.inference_model,
|
||||
self.generator.tokenizer,
|
||||
self.generator.vision_processor,
|
||||
) = yield from load_mlx_items(
|
||||
self.bound_instance,
|
||||
self.generator.group,
|
||||
)
|
||||
|
||||
for load_resp in load_model():
|
||||
self.update_status(
|
||||
RunnerLoading(
|
||||
layers_loaded=load_resp.layers_loaded,
|
||||
total_layers=load_resp.total,
|
||||
)
|
||||
)
|
||||
|
||||
self.generator = self.generator.build()
|
||||
|
||||
@@ -278,9 +279,7 @@ class Runner:
|
||||
self.send_task_status(task_id, TaskStatus.Complete)
|
||||
finished.append(task_id)
|
||||
case _:
|
||||
self.send_response(
|
||||
result, self.active_tasks[task_id].command_id
|
||||
)
|
||||
self.send_chunk(result, self.active_tasks[task_id].command_id)
|
||||
|
||||
for task_id in finished:
|
||||
self.active_tasks.pop(task_id, None)
|
||||
@@ -313,59 +312,13 @@ class Runner:
|
||||
|
||||
return ExitCode.AllTasksComplete
|
||||
|
||||
def send_response(
|
||||
def send_chunk(
|
||||
self,
|
||||
response: GenerationResponse | ToolCallResponse,
|
||||
chunk: GenerationChunk,
|
||||
command_id: CommandId,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if self.device_rank == 0 and response.finish_reason == "error":
|
||||
self.event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=self.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif self.device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
self.event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
usage=response.usage,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
is_thinking=response.is_thinking,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if self.device_rank == 0:
|
||||
self.event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=self.model_id,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
if self.device_rank == 0:
|
||||
self.event_sender.send(ChunkGenerated(command_id=command_id, chunk=chunk))
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
import unittest.mock
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -115,13 +116,22 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
assert test_event == true_event, f"{test_event} != {true_event}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockLoadOutput:
|
||||
layers_loaded: int
|
||||
total: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer, None))
|
||||
)
|
||||
|
||||
def lmi_gen():
|
||||
yield MockLoadOutput(1, 1)
|
||||
return (1, MockTokenizer, None)
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin(lmi_gen()))
|
||||
monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
|
||||
@@ -318,6 +328,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
runner_status=RunnerLoading(layers_loaded=0, total_layers=32),
|
||||
),
|
||||
TaskAcknowledged(task_id=LOAD_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID,
|
||||
runner_status=RunnerLoading(layers_loaded=1, total_layers=1),
|
||||
),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
|
||||
|
||||
Reference in New Issue
Block a user