diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index cda11ffaa..76bd6fd4e 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -2,6 +2,7 @@ from enum import Enum from pydantic import model_validator +from exo.shared.models.model_cards import ModelTask from exo.shared.types.common import Host, Id, NodeId from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel @@ -49,6 +50,13 @@ class BoundInstance(CamelCaseModel): assert shard is not None return shard + @property + def is_image_model(self) -> bool: + return ( + ModelTask.TextToImage in self.bound_shard.model_card.tasks + or ModelTask.ImageToImage in self.bound_shard.model_card.tasks + ) + @model_validator(mode="after") def validate_shard_exists(self) -> "BoundInstance": assert ( diff --git a/src/exo/worker/runner/__init__.py b/src/exo/worker/runner/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index ed420aab3..9949cb7e6 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -37,9 +37,13 @@ def entrypoint( # Import main after setting global logger - this lets us just import logger from this module try: - from exo.worker.runner.runner import main + if bound_instance.is_image_model: + from exo.worker.runner.image_models.runner import main + else: + from exo.worker.runner.llm_inference.runner import main main(bound_instance, event_sender, task_receiver, cancel_receiver) + except ClosedResourceError: logger.warning("Runner communication closed unexpectedly") except Exception as e: diff --git a/src/exo/worker/runner/image_models/__init__.py b/src/exo/worker/runner/image_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/worker/runner/image_models/runner.py b/src/exo/worker/runner/image_models/runner.py new file mode 100644 index 000000000..d950be9eb --- /dev/null +++ b/src/exo/worker/runner/image_models/runner.py @@ -0,0 +1,453 @@ +import base64 +import resource +import time +from typing import TYPE_CHECKING, Literal + +import mlx.core as mx + +from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED +from exo.shared.models.model_cards import ModelTask +from exo.shared.tracing import clear_trace_buffer, get_trace_buffer +from exo.shared.types.api import ImageGenerationStats +from exo.shared.types.chunks import ErrorChunk, ImageChunk +from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.events import ( + ChunkGenerated, + Event, + RunnerStatusUpdated, + TaskAcknowledged, + TaskStatusUpdated, + TraceEventData, + TracesCollected, +) +from exo.shared.types.tasks import ( + ConnectToGroup, + ImageEdits, + ImageGeneration, + LoadModel, + Shutdown, + StartWarmup, + Task, + TaskId, + TaskStatus, +) +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ( + ImageGenerationResponse, + PartialImageResponse, +) +from exo.shared.types.worker.runners import ( + RunnerConnected, + RunnerConnecting, + RunnerFailed, + RunnerIdle, + RunnerLoaded, + RunnerLoading, + RunnerReady, + RunnerRunning, + RunnerShutdown, + RunnerShuttingDown, + RunnerStatus, + RunnerWarmingUp, +) +from exo.shared.types.worker.shards import ( + CfgShardMetadata, + PipelineShardMetadata, + ShardMetadata, +) +from exo.utils.channels 
import MpReceiver, MpSender +from exo.worker.engines.image import ( + DistributedImageModel, + generate_image, + initialize_image_model, + warmup_image_generator, +) +from exo.worker.engines.mlx.utils_mlx import ( + initialize_mlx, +) +from exo.worker.runner.bootstrap import logger + + +def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool: + """Check if this node is the primary output node for image generation. + + For CFG models: the last pipeline stage in CFG group 0 (positive prompt). + For non-CFG models: the last pipeline stage. + """ + if isinstance(shard_metadata, CfgShardMetadata): + is_pipeline_last = ( + shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1 + ) + return is_pipeline_last and shard_metadata.cfg_rank == 0 + elif isinstance(shard_metadata, PipelineShardMetadata): + return shard_metadata.device_rank == shard_metadata.world_size - 1 + return False + + +def _process_image_response( + response: ImageGenerationResponse | PartialImageResponse, + command_id: CommandId, + shard_metadata: ShardMetadata, + event_sender: MpSender[Event], + image_index: int, +) -> None: + """Process a single image response and send chunks.""" + encoded_data = base64.b64encode(response.image_data).decode("utf-8") + is_partial = isinstance(response, PartialImageResponse) + # Extract stats from final ImageGenerationResponse if available + stats = response.stats if isinstance(response, ImageGenerationResponse) else None + _send_image_chunk( + encoded_data=encoded_data, + command_id=command_id, + model_id=shard_metadata.model_card.model_id, + event_sender=event_sender, + image_index=response.image_index, + is_partial=is_partial, + partial_index=response.partial_index if is_partial else None, + total_partials=response.total_partials if is_partial else None, + stats=stats, + image_format=response.format, + ) + + +def _send_traces_if_enabled( + event_sender: MpSender[Event], + task_id: TaskId, + rank: int, +) -> None: + if not EXO_TRACING_ENABLED: + return + + traces = get_trace_buffer() + if traces: + trace_data = [ + TraceEventData( + name=t.name, + start_us=t.start_us, + duration_us=t.duration_us, + rank=t.rank, + category=t.category, + ) + for t in traces + ] + event_sender.send( + TracesCollected( + task_id=task_id, + rank=rank, + traces=trace_data, + ) + ) + clear_trace_buffer() + + +def _send_image_chunk( + encoded_data: str, + command_id: CommandId, + model_id: ModelId, + event_sender: MpSender[Event], + image_index: int, + is_partial: bool, + partial_index: int | None = None, + total_partials: int | None = None, + stats: ImageGenerationStats | None = None, + image_format: Literal["png", "jpeg", "webp"] | None = None, +) -> None: + """Send base64-encoded image data as chunks via events.""" + data_chunks = [ + encoded_data[i : i + EXO_MAX_CHUNK_SIZE] + for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE) + ] + total_chunks = len(data_chunks) + for chunk_index, chunk_data in enumerate(data_chunks): + # Only include stats on the last chunk of the final image + chunk_stats = ( + stats if chunk_index == total_chunks - 1 and not is_partial else None + ) + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ImageChunk( + model=model_id, + data=chunk_data, + chunk_index=chunk_index, + total_chunks=total_chunks, + image_index=image_index, + is_partial=is_partial, + partial_index=partial_index, + total_partials=total_partials, + stats=chunk_stats, + format=image_format, + ), + ) + ) + + +def main( + bound_instance: BoundInstance, + event_sender: 
MpSender[Event],
+    task_receiver: MpReceiver[Task],
+    cancel_receiver: MpReceiver[TaskId],
+):
+    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
+
+    instance, runner_id, shard_metadata = (
+        bound_instance.instance,
+        bound_instance.bound_runner_id,
+        bound_instance.bound_shard,
+    )
+    device_rank = shard_metadata.device_rank
+    logger.info("hello from the runner")
+    if getattr(shard_metadata, "immediate_exception", False):
+        raise Exception("Fake exception - runner failed to spin up.")
+    if timeout := getattr(shard_metadata, "should_timeout", 0):
+        time.sleep(timeout)
+
+    setup_start_time = time.time()
+    cancelled_tasks = set[TaskId]()
+
+    image_model: DistributedImageModel | None = None
+    group = None
+
+    current_status: RunnerStatus = RunnerIdle()
+    logger.info("runner created")
+    event_sender.send(
+        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
+    )
+    seen = set[TaskId]()
+    with task_receiver as tasks:
+        for task in tasks:
+            if task.task_id in seen:
+                logger.warning("repeat task - potential error")
+            seen.add(task.task_id)
+            cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
+            event_sender.send(
+                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
+            )
+            match task:
+                case ConnectToGroup() if isinstance(
+                    current_status, (RunnerIdle, RunnerFailed)
+                ):
+                    logger.info("runner connecting")
+                    current_status = RunnerConnecting()
+                    event_sender.send(
+                        RunnerStatusUpdated(
+                            runner_id=runner_id, runner_status=current_status
+                        )
+                    )
+                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
+                    group = initialize_mlx(bound_instance)
+
+                    logger.info("runner connected")
+                    current_status = RunnerConnected()
+
+                # We load the model when the runner is connected with a group, or idle without a group;
we should never tell a runner to connect when it doesn't need to.
+                case LoadModel() if (
+                    isinstance(current_status, RunnerConnected) and group is not None
+                ) or (isinstance(current_status, RunnerIdle) and group is None):
+                    current_status = RunnerLoading()
+                    logger.info("runner loading")
+                    event_sender.send(
+                        RunnerStatusUpdated(
+                            runner_id=runner_id, runner_status=current_status
+                        )
+                    )
+                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+                    assert (
+                        ModelTask.TextToImage in shard_metadata.model_card.tasks
+                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
+                    ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
+
+                    image_model = initialize_image_model(bound_instance)
+                    current_status = RunnerLoaded()
+                    logger.info("runner loaded")
+
+                case StartWarmup() if isinstance(current_status, RunnerLoaded):
+                    current_status = RunnerWarmingUp()
+                    logger.info("runner warming up")
+                    event_sender.send(
+                        RunnerStatusUpdated(
+                            runner_id=runner_id, runner_status=current_status
+                        )
+                    )
+                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+                    logger.info(f"warming up inference for instance: {instance}")
+
+                    assert image_model
+                    image = warmup_image_generator(model=image_model)
+                    if image is not None:
+                        logger.info(f"warmed up by generating {image.size} image")
+                    else:
+                        logger.info("warmup completed (non-primary node)")
+
+                    logger.info(
+                        f"runner initialized in {time.time() - setup_start_time} seconds"
+                    )
+
+                    current_status = RunnerReady()
+                    logger.info("runner ready")
+
+                case ImageGeneration(
+                    task_params=task_params, command_id=command_id
+                ) if isinstance(current_status, RunnerReady):
+                    assert image_model
+                    logger.info(f"received image generation request: {str(task)[:500]}")
+                    current_status = RunnerRunning()
+                    logger.info("runner running")
+                    event_sender.send(
+                        RunnerStatusUpdated(
+                            runner_id=runner_id, runner_status=current_status
+                        )
+                    )
+                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+                    try:
+                        image_index = 0
+                        for response in generate_image(
+                            model=image_model, task=task_params
+                        ):
+                            is_primary_output = _is_primary_output_node(shard_metadata)
+
+                            if is_primary_output:
+                                match response:
+                                    case PartialImageResponse():
+                                        logger.info(
+                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
+                                        )
+                                        _process_image_response(
+                                            response,
+                                            command_id,
+                                            shard_metadata,
+                                            event_sender,
+                                            image_index,
+                                        )
+                                    case ImageGenerationResponse():
+                                        logger.info("sending final ImageChunk")
+                                        _process_image_response(
+                                            response,
+                                            command_id,
+                                            shard_metadata,
+                                            event_sender,
+                                            image_index,
+                                        )
+                                        image_index += 1
+                    # can we make this more explicit?
+ except Exception as e: + if _is_primary_output_node(shard_metadata): + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ErrorChunk( + model=shard_metadata.model_card.model_id, + finish_reason="error", + error_message=str(e), + ), + ) + ) + raise + finally: + _send_traces_if_enabled(event_sender, task.task_id, device_rank) + + current_status = RunnerReady() + logger.info("runner ready") + + case ImageEdits(task_params=task_params, command_id=command_id) if ( + isinstance(current_status, RunnerReady) + ): + assert image_model + logger.info(f"received image edits request: {str(task)[:500]}") + current_status = RunnerRunning() + logger.info("runner running") + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + try: + image_index = 0 + for response in generate_image( + model=image_model, task=task_params + ): + if _is_primary_output_node(shard_metadata): + match response: + case PartialImageResponse(): + logger.info( + f"sending partial ImageChunk {response.partial_index}/{response.total_partials}" + ) + _process_image_response( + response, + command_id, + shard_metadata, + event_sender, + image_index, + ) + case ImageGenerationResponse(): + logger.info("sending final ImageChunk") + _process_image_response( + response, + command_id, + shard_metadata, + event_sender, + image_index, + ) + image_index += 1 + except Exception as e: + if _is_primary_output_node(shard_metadata): + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ErrorChunk( + model=shard_metadata.model_card.model_id, + finish_reason="error", + error_message=str(e), + ), + ) + ) + raise + finally: + _send_traces_if_enabled(event_sender, task.task_id, device_rank) + + current_status = RunnerReady() + logger.info("runner ready") + + case Shutdown(): + current_status = RunnerShuttingDown() + logger.info("runner shutting down") + if not TYPE_CHECKING: + del image_model, group + mx.clear_cache() + import gc + + gc.collect() + + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + current_status = RunnerShutdown() + case _: + raise ValueError( + f"Received {task.__class__.__name__} outside of state machine in {current_status=}" + ) + was_cancelled = (task.task_id in cancelled_tasks) or ( + TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks + ) + if not was_cancelled: + event_sender.send( + TaskStatusUpdated( + task_id=task.task_id, task_status=TaskStatus.Complete + ) + ) + event_sender.send( + RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status) + ) + + if isinstance(current_status, RunnerShutdown): + break diff --git a/src/exo/worker/runner/llm_inference/__init__.py b/src/exo/worker/runner/llm_inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/llm_inference/runner.py similarity index 66% rename from src/exo/worker/runner/runner.py rename to src/exo/worker/runner/llm_inference/runner.py index 47293d8ce..5581f6dcc 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/llm_inference/runner.py @@ -1,10 +1,9 @@ -import base64 import math import resource import time from collections.abc import Generator from functools import cache -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, cast import mlx.core as mx from 
mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model @@ -18,31 +17,22 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs] load_harmony_encoding, ) -from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED -from exo.shared.models.model_cards import ModelId, ModelTask -from exo.shared.tracing import clear_trace_buffer, get_trace_buffer -from exo.shared.types.api import ImageGenerationStats +from exo.shared.models.model_cards import ModelTask from exo.shared.types.chunks import ( ErrorChunk, - ImageChunk, PrefillProgressChunk, TokenChunk, ToolCallChunk, ) -from exo.shared.types.common import CommandId from exo.shared.types.events import ( ChunkGenerated, Event, RunnerStatusUpdated, TaskAcknowledged, TaskStatusUpdated, - TraceEventData, - TracesCollected, ) from exo.shared.types.tasks import ( ConnectToGroup, - ImageEdits, - ImageGeneration, LoadModel, Shutdown, StartWarmup, @@ -55,8 +45,6 @@ from exo.shared.types.text_generation import TextGenerationTaskParams from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.runner_response import ( GenerationResponse, - ImageGenerationResponse, - PartialImageResponse, ToolCallItem, ToolCallResponse, ) @@ -74,18 +62,7 @@ from exo.shared.types.worker.runners import ( RunnerStatus, RunnerWarmingUp, ) -from exo.shared.types.worker.shards import ( - CfgShardMetadata, - PipelineShardMetadata, - ShardMetadata, -) from exo.utils.channels import MpReceiver, MpSender -from exo.worker.engines.image import ( - DistributedImageModel, - generate_image, - initialize_image_model, - warmup_image_generator, -) from exo.worker.engines.mlx import Model from exo.worker.engines.mlx.cache import KVPrefixCache from exo.worker.engines.mlx.generator.generate import ( @@ -106,22 +83,6 @@ from exo.worker.runner.bootstrap import logger from .tool_parsers import ToolParser, make_mlx_parser -def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool: - """Check if this node is the primary output node for image generation. - - For CFG models: the last pipeline stage in CFG group 0 (positive prompt). - For non-CFG models: the last pipeline stage. 
- """ - if isinstance(shard_metadata, CfgShardMetadata): - is_pipeline_last = ( - shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1 - ) - return is_pipeline_last and shard_metadata.cfg_rank == 0 - elif isinstance(shard_metadata, PipelineShardMetadata): - return shard_metadata.device_rank == shard_metadata.world_size - 1 - return False - - def main( bound_instance: BoundInstance, event_sender: MpSender[Event], @@ -146,9 +107,7 @@ def main( setup_start_time = time.time() cancelled_tasks = set[TaskId]() - # type checker was unhappy with me - splitting these fixed it inference_model: Model | None = None - image_model: DistributedImageModel | None = None tokenizer = None tool_parser: ToolParser | None = None group = None @@ -211,33 +170,25 @@ def main( ) time.sleep(0.5) - if ModelTask.TextGeneration in shard_metadata.model_card.tasks: - inference_model, tokenizer = load_mlx_items( - bound_instance, group, on_timeout=on_model_load_timeout - ) - logger.info( - f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}" - ) - if tokenizer.has_tool_calling: - assert tokenizer.tool_call_start - assert tokenizer.tool_call_end - assert tokenizer.tool_parser # pyright: ignore[reportAny] - tool_parser = make_mlx_parser( - tokenizer.tool_call_start, - tokenizer.tool_call_end, - tokenizer.tool_parser, # pyright: ignore[reportAny] - ) - kv_prefix_cache = KVPrefixCache(group) - - elif ( - ModelTask.TextToImage in shard_metadata.model_card.tasks - or ModelTask.ImageToImage in shard_metadata.model_card.tasks - ): - image_model = initialize_image_model(bound_instance) - else: - raise ValueError( - f"Unknown model task(s): {shard_metadata.model_card.tasks}" + assert ( + ModelTask.TextGeneration in shard_metadata.model_card.tasks + ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}" + inference_model, tokenizer = load_mlx_items( + bound_instance, group, on_timeout=on_model_load_timeout + ) + logger.info( + f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}" + ) + if tokenizer.has_tool_calling: + assert tokenizer.tool_call_start + assert tokenizer.tool_call_end + assert tokenizer.tool_parser # pyright: ignore[reportAny] + tool_parser = make_mlx_parser( + tokenizer.tool_call_start, + tokenizer.tool_call_end, + tokenizer.tool_parser, # pyright: ignore[reportAny] ) + kv_prefix_cache = KVPrefixCache(group) current_status = RunnerLoaded() logger.info("runner loaded") case StartWarmup() if isinstance(current_status, RunnerLoaded): @@ -251,46 +202,34 @@ def main( event_sender.send(TaskAcknowledged(task_id=task.task_id)) logger.info(f"warming up inference for instance: {instance}") - if ModelTask.TextGeneration in shard_metadata.model_card.tasks: - assert inference_model - assert tokenizer + assert inference_model + assert tokenizer - t = time.monotonic() - toks = warmup_inference( - model=inference_model, - tokenizer=tokenizer, - group=group, + t = time.monotonic() + toks = warmup_inference( + model=cast(Model, inference_model), + tokenizer=tokenizer, + group=group, + ) + logger.info(f"warmed up by generating {toks} tokens") + check_for_cancel_every = min( + math.ceil(toks / min(time.monotonic() - t, 0.001)), 100 + ) + if group is not None: + check_for_cancel_every = int( + mx.max( + mx.distributed.all_gather( + mx.array([check_for_cancel_every]), group=group + ) + ).item() ) - logger.info(f"warmed up by generating {toks} tokens") - check_for_cancel_every 
= min( - math.ceil(toks / min(time.monotonic() - t, 0.001)), 100 - ) - if group is not None: - check_for_cancel_every = int( - mx.max( - mx.distributed.all_gather( - mx.array([check_for_cancel_every]), group=group - ) - ).item() - ) - - logger.info( - f"runner checking for cancellation every {check_for_cancel_every} tokens" - ) - logger.info( - f"runner initialized in {time.time() - setup_start_time} seconds" - ) - elif ( - ModelTask.TextToImage in shard_metadata.model_card.tasks - or ModelTask.ImageToImage in shard_metadata.model_card.tasks - ): - assert image_model - image = warmup_image_generator(model=image_model) - if image is not None: - logger.info(f"warmed up by generating {image.size} image") - else: - logger.info("warmup completed (non-primary node)") + logger.info( + f"runner checking for cancellation every {check_for_cancel_every} tokens" + ) + logger.info( + f"runner initialized in {time.time() - setup_start_time} seconds" + ) current_status = RunnerReady() logger.info("runner ready") case TextGeneration(task_params=task_params, command_id=command_id) if ( @@ -345,7 +284,7 @@ def main( # Generate responses using the actual MLX generation mlx_generator = mlx_generate( - model=inference_model, + model=cast(Model, inference_model), tokenizer=tokenizer, task=task_params, prompt=prompt, @@ -458,138 +397,12 @@ def main( current_status = RunnerReady() logger.info("runner ready") - case ImageGeneration( - task_params=task_params, command_id=command_id - ) if isinstance(current_status, RunnerReady): - assert image_model - logger.info(f"received image generation request: {str(task)[:500]}") - current_status = RunnerRunning() - logger.info("runner running") - event_sender.send( - RunnerStatusUpdated( - runner_id=runner_id, runner_status=current_status - ) - ) - event_sender.send(TaskAcknowledged(task_id=task.task_id)) - try: - image_index = 0 - for response in generate_image( - model=image_model, task=task_params - ): - is_primary_output = _is_primary_output_node(shard_metadata) - - if is_primary_output: - match response: - case PartialImageResponse(): - logger.info( - f"sending partial ImageChunk {response.partial_index}/{response.total_partials}" - ) - _process_image_response( - response, - command_id, - shard_metadata, - event_sender, - image_index, - ) - case ImageGenerationResponse(): - logger.info("sending final ImageChunk") - _process_image_response( - response, - command_id, - shard_metadata, - event_sender, - image_index, - ) - image_index += 1 - # can we make this more explicit? 
- except Exception as e: - if _is_primary_output_node(shard_metadata): - event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ErrorChunk( - model=shard_metadata.model_card.model_id, - finish_reason="error", - error_message=str(e), - ), - ) - ) - raise - finally: - _send_traces_if_enabled( - event_sender, task.task_id, shard_metadata.device_rank - ) - - current_status = RunnerReady() - logger.info("runner ready") - case ImageEdits(task_params=task_params, command_id=command_id) if ( - isinstance(current_status, RunnerReady) - ): - assert image_model - logger.info(f"received image edits request: {str(task)[:500]}") - current_status = RunnerRunning() - logger.info("runner running") - event_sender.send( - RunnerStatusUpdated( - runner_id=runner_id, runner_status=current_status - ) - ) - event_sender.send(TaskAcknowledged(task_id=task.task_id)) - - try: - image_index = 0 - for response in generate_image( - model=image_model, task=task_params - ): - if _is_primary_output_node(shard_metadata): - match response: - case PartialImageResponse(): - logger.info( - f"sending partial ImageChunk {response.partial_index}/{response.total_partials}" - ) - _process_image_response( - response, - command_id, - shard_metadata, - event_sender, - image_index, - ) - case ImageGenerationResponse(): - logger.info("sending final ImageChunk") - _process_image_response( - response, - command_id, - shard_metadata, - event_sender, - image_index, - ) - image_index += 1 - except Exception as e: - if _is_primary_output_node(shard_metadata): - event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ErrorChunk( - model=shard_metadata.model_card.model_id, - finish_reason="error", - error_message=str(e), - ), - ) - ) - raise - finally: - _send_traces_if_enabled( - event_sender, task.task_id, shard_metadata.device_rank - ) - - current_status = RunnerReady() - logger.info("runner ready") case Shutdown(): current_status = RunnerShuttingDown() logger.info("runner shutting down") if not TYPE_CHECKING: - del inference_model, image_model, tokenizer, group + del inference_model, tokenizer, group mx.clear_cache() import gc @@ -890,104 +703,6 @@ def parse_thinking_models( yield response.model_copy(update={"is_thinking": in_thinking}) -def _send_image_chunk( - encoded_data: str, - command_id: CommandId, - model_id: ModelId, - event_sender: MpSender[Event], - image_index: int, - is_partial: bool, - partial_index: int | None = None, - total_partials: int | None = None, - stats: ImageGenerationStats | None = None, - image_format: Literal["png", "jpeg", "webp"] | None = None, -) -> None: - """Send base64-encoded image data as chunks via events.""" - data_chunks = [ - encoded_data[i : i + EXO_MAX_CHUNK_SIZE] - for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE) - ] - total_chunks = len(data_chunks) - for chunk_index, chunk_data in enumerate(data_chunks): - # Only include stats on the last chunk of the final image - chunk_stats = ( - stats if chunk_index == total_chunks - 1 and not is_partial else None - ) - event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ImageChunk( - model=model_id, - data=chunk_data, - chunk_index=chunk_index, - total_chunks=total_chunks, - image_index=image_index, - is_partial=is_partial, - partial_index=partial_index, - total_partials=total_partials, - stats=chunk_stats, - format=image_format, - ), - ) - ) - - -def _send_traces_if_enabled( - event_sender: MpSender[Event], - task_id: TaskId, - rank: int, -) -> None: - if not EXO_TRACING_ENABLED: - return - - traces = 
get_trace_buffer() - if traces: - trace_data = [ - TraceEventData( - name=t.name, - start_us=t.start_us, - duration_us=t.duration_us, - rank=t.rank, - category=t.category, - ) - for t in traces - ] - event_sender.send( - TracesCollected( - task_id=task_id, - rank=rank, - traces=trace_data, - ) - ) - clear_trace_buffer() - - -def _process_image_response( - response: ImageGenerationResponse | PartialImageResponse, - command_id: CommandId, - shard_metadata: ShardMetadata, - event_sender: MpSender[Event], - image_index: int, -) -> None: - """Process a single image response and send chunks.""" - encoded_data = base64.b64encode(response.image_data).decode("utf-8") - is_partial = isinstance(response, PartialImageResponse) - # Extract stats from final ImageGenerationResponse if available - stats = response.stats if isinstance(response, ImageGenerationResponse) else None - _send_image_chunk( - encoded_data=encoded_data, - command_id=command_id, - model_id=shard_metadata.model_card.model_id, - event_sender=event_sender, - image_index=response.image_index, - is_partial=is_partial, - partial_index=response.partial_index if is_partial else None, - total_partials=response.total_partials if is_partial else None, - stats=stats, - image_format=response.format, - ) - - def parse_tool_calls( responses: Generator[GenerationResponse], tool_parser: ToolParser ) -> Generator[GenerationResponse | ToolCallResponse]: diff --git a/src/exo/worker/runner/tool_parsers.py b/src/exo/worker/runner/llm_inference/tool_parsers.py similarity index 100% rename from src/exo/worker/runner/tool_parsers.py rename to src/exo/worker/runner/llm_inference/tool_parsers.py diff --git a/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py b/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py index a59383e5d..63b3587fd 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py +++ b/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py @@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import ( encode_messages, parse_dsml_output, ) -from exo.worker.runner.runner import parse_deepseek_v32 +from exo.worker.runner.llm_inference.runner import parse_deepseek_v32 # ── Shared fixtures ────────────────────────────────────────────── diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index 864e424bd..6d964583e 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -6,7 +6,7 @@ from typing import Callable import mlx.core as mx import pytest -import exo.worker.runner.runner as mlx_runner +import exo.worker.runner.llm_inference.runner as mlx_runner from exo.shared.types.chunks import TokenChunk from exo.shared.types.events import ( ChunkGenerated, @@ -180,7 +180,7 @@ def _run(tasks: Iterable[Task]): task_receiver.close = nothin task_receiver.join = nothin with unittest.mock.patch( - "exo.worker.runner.runner.mx.distributed.all_gather", + "exo.worker.runner.llm_inference.runner.mx.distributed.all_gather", make_nothin(mx.array([1])), ): mlx_runner.main( diff --git a/src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py b/src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py index 080f03890..6302e4a9d 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py +++ b/src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py @@ -4,7 +4,7 @@ from 
exo.shared.types.worker.runner_response import ( GenerationResponse, ToolCallResponse, ) -from exo.worker.runner.runner import parse_gpt_oss +from exo.worker.runner.llm_inference.runner import parse_gpt_oss # Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer. # These are stable since they come from the model's vocabulary. diff --git a/src/exo/worker/tests/unittests/test_runner/test_parse_tool_calls.py b/src/exo/worker/tests/unittests/test_runner/test_parse_tool_calls.py index 8a23a18c1..32d331f6d 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_parse_tool_calls.py +++ b/src/exo/worker/tests/unittests/test_runner/test_parse_tool_calls.py @@ -4,8 +4,8 @@ from collections.abc import Generator from typing import Any from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse -from exo.worker.runner.runner import parse_tool_calls -from exo.worker.runner.tool_parsers import make_mlx_parser +from exo.worker.runner.llm_inference.runner import parse_tool_calls +from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser def _make_responses(