Refactor runner into separate runners (#1570)

## Motivation

We're about to refactor the LLM inference code, so we should split the runner into separate parts while we can.
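
Judging by the import paths in the diffs below, the split looks roughly like this (helper modules omitted; this layout is inferred from the imports, not an authoritative listing):

    exo/worker/runner/
        bootstrap.py              # shared process setup and logger
        llm_inference/
            runner.py             # text-generation state machine (was runner/runner.py)
            tool_parsers.py       # moved alongside the LLM runner
        image_models/
            runner.py             # new image generation/editing state machine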

## Test Plan

### Manual Testing
Verified working on a single node, at least.

### Automated Testing
Passes CI. The existing test suite already exercises these paths.
rltakashige committed 2026-02-20 17:11:01 +00:00 (via GitHub)
parent dc8d42b4dc, commit 1a2b8b044a
12 changed files with 519 additions and 339 deletions

View File

@@ -2,6 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -49,6 +50,13 @@ class BoundInstance(CamelCaseModel):
assert shard is not None
return shard
@property
def is_image_model(self) -> bool:
return (
ModelTask.TextToImage in self.bound_shard.model_card.tasks
or ModelTask.ImageToImage in self.bound_shard.model_card.tasks
)
@model_validator(mode="after")
def validate_shard_exists(self) -> "BoundInstance":
assert (
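
The new `is_image_model` property drives everything downstream, and it keys purely off the model card's task list. A standalone sketch of the same predicate (the stub types and enum values here are illustrative, not exo's real definitions):

    from enum import Enum

    class ModelTask(Enum):  # values are placeholders
        TextGeneration = "text-generation"
        TextToImage = "text-to-image"
        ImageToImage = "image-to-image"

    class ModelCard:  # minimal stand-in for exo's model card
        def __init__(self, tasks: set[ModelTask]) -> None:
            self.tasks = tasks

    def is_image_model(card: ModelCard) -> bool:
        # Same rule as BoundInstance.is_image_model above.
        return (
            ModelTask.TextToImage in card.tasks
            or ModelTask.ImageToImage in card.tasks
        )

    assert is_image_model(ModelCard({ModelTask.ImageToImage}))
    assert not is_image_model(ModelCard({ModelTask.TextGeneration}))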

View File

@@ -37,9 +37,13 @@ def entrypoint(
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
if bound_instance.is_image_model:
from exo.worker.runner.image_models.runner import main
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
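
The entrypoint now picks a runner module per process instead of importing one monolithic `main`. Deferring the import until the bound instance is known means each worker process only loads the engine it will actually use. A condensed sketch of the dispatch (the helper name is hypothetical; the real entrypoint inlines the if/else and passes the event/task channels through):

    def select_runner_main(bound_instance):
        """Lazily import and return the right runner entrypoint."""
        if bound_instance.is_image_model:
            from exo.worker.runner.image_models.runner import main
        else:
            from exo.worker.runner.llm_inference.runner import main
        return main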

View File

@@ -0,0 +1,453 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
import mlx.core as mx
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.api import ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk
from exo.shared.types.common import CommandId, ModelId
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
)
from exo.shared.types.tasks import (
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
ShardMetadata,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
For CFG models: the last pipeline stage in CFG group 0 (positive prompt).
For non-CFG models: the last pipeline stage.
"""
if isinstance(shard_metadata, CfgShardMetadata):
is_pipeline_last = (
shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1
)
return is_pipeline_last and shard_metadata.cfg_rank == 0
elif isinstance(shard_metadata, PipelineShardMetadata):
return shard_metadata.device_rank == shard_metadata.world_size - 1
return False
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
# Extract stats from final ImageGenerationResponse if available
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
stats=stats,
image_format=response.format,
)
def _send_traces_if_enabled(
event_sender: MpSender[Event],
task_id: TaskId,
rank: int,
) -> None:
if not EXO_TRACING_ENABLED:
return
traces = get_trace_buffer()
if traces:
trace_data = [
TraceEventData(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in traces
]
event_sender.send(
TracesCollected(
task_id=task_id,
rank=rank,
traces=trace_data,
)
)
clear_trace_buffer()
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
stats: ImageGenerationStats | None = None,
image_format: Literal["png", "jpeg", "webp"] | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
# Only include stats on the last chunk of the final image
chunk_stats = (
stats if chunk_index == total_chunks - 1 and not is_partial else None
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
stats=chunk_stats,
format=image_format,
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
if timeout := getattr(shard_metadata, "should_timeout", 0):
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
image_model: DistributedImageModel | None = None
group = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
seen = set[TaskId]()
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# We load the model when the runner is connected with a group, or idle without one; we should never tell a runner to connect if it doesn't need to.
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
image_model = initialize_image_model(bound_instance)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
logger.info(f"warming up inference for instance: {instance}")
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if _is_primary_output_node(shard_metadata):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
except Exception as e:
if _is_primary_output_node(shard_metadata):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
if not TYPE_CHECKING:
del image_model, group
mx.clear_cache()
import gc
gc.collect()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
was_cancelled = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if not was_cancelled:
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
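
Exactly one rank per instance should emit image chunks, and `_is_primary_output_node` encodes which. A self-contained sketch of the rule with illustrative metadata types (exo's real `CfgShardMetadata`/`PipelineShardMetadata` carry more fields):

    from dataclasses import dataclass

    @dataclass
    class PipelineShard:  # stand-in for PipelineShardMetadata
        device_rank: int
        world_size: int

    @dataclass
    class CfgShard:  # stand-in for CfgShardMetadata
        pipeline_rank: int
        pipeline_world_size: int
        cfg_rank: int

    def is_primary_output(shard: object) -> bool:
        if isinstance(shard, CfgShard):
            # CFG: last pipeline stage of CFG group 0 (positive prompt).
            return (
                shard.pipeline_rank == shard.pipeline_world_size - 1
                and shard.cfg_rank == 0
            )
        if isinstance(shard, PipelineShard):
            # Non-CFG: the last pipeline stage.
            return shard.device_rank == shard.world_size - 1
        return False

    # Two CFG groups of two pipeline stages: only one of four ranks emits.
    assert is_primary_output(CfgShard(1, 2, cfg_rank=0))
    assert not is_primary_output(CfgShard(1, 2, cfg_rank=1))
    assert not is_primary_output(CfgShard(0, 2, cfg_rank=0))
    assert is_primary_output(PipelineShard(device_rank=3, world_size=4))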

View File

@@ -1,10 +1,9 @@
import base64
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, cast
import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -18,31 +17,22 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
load_harmony_encoding,
)
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.api import ImageGenerationStats
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
)
from exo.shared.types.tasks import (
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -55,8 +45,6 @@ from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
ToolCallItem,
ToolCallResponse,
)
@@ -74,18 +62,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
ShardMetadata,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
@@ -106,22 +83,6 @@ from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
For CFG models: the last pipeline stage in CFG group 0 (positive prompt).
For non-CFG models: the last pipeline stage.
"""
if isinstance(shard_metadata, CfgShardMetadata):
is_pipeline_last = (
shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1
)
return is_pipeline_last and shard_metadata.cfg_rank == 0
elif isinstance(shard_metadata, PipelineShardMetadata):
return shard_metadata.device_rank == shard_metadata.world_size - 1
return False
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -146,9 +107,7 @@ def main(
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
@@ -211,33 +170,25 @@ def main(
)
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
assert (
ModelTask.TextGeneration in shard_metadata.model_card.tasks
), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -251,46 +202,34 @@ def main(
event_sender.send(TaskAcknowledged(task_id=task.task_id))
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert tokenizer
assert inference_model
assert tokenizer
t = time.monotonic()
toks = warmup_inference(
model=inference_model,
tokenizer=tokenizer,
group=group,
t = time.monotonic()
toks = warmup_inference(
model=cast(Model, inference_model),
tokenizer=tokenizer,
group=group,
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case TextGeneration(task_params=task_params, command_id=command_id) if (
@@ -345,7 +284,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=cast(Model, inference_model),
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -458,138 +397,12 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if _is_primary_output_node(shard_metadata):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
except Exception as e:
if _is_primary_output_node(shard_metadata):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
if not TYPE_CHECKING:
del inference_model, image_model, tokenizer, group
del inference_model, tokenizer, group
mx.clear_cache()
import gc
@@ -890,104 +703,6 @@ def parse_thinking_models(
yield response.model_copy(update={"is_thinking": in_thinking})
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
stats: ImageGenerationStats | None = None,
image_format: Literal["png", "jpeg", "webp"] | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
# Only include stats on the last chunk of the final image
chunk_stats = (
stats if chunk_index == total_chunks - 1 and not is_partial else None
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
stats=chunk_stats,
format=image_format,
),
)
)
def _send_traces_if_enabled(
event_sender: MpSender[Event],
task_id: TaskId,
rank: int,
) -> None:
if not EXO_TRACING_ENABLED:
return
traces = get_trace_buffer()
if traces:
trace_data = [
TraceEventData(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in traces
]
event_sender.send(
TracesCollected(
task_id=task_id,
rank=rank,
traces=trace_data,
)
)
clear_trace_buffer()
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
# Extract stats from final ImageGenerationResponse if available
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
stats=stats,
image_format=response.format,
)
def parse_tool_calls(
responses: Generator[GenerationResponse], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse]:
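
One piece of the retained warmup path worth calling out: `check_for_cancel_every` is derived from warmup throughput, so the runner polls the cancel channel roughly once per second of generation, capped at every 100 tokens, and the group then agrees on the largest value via `mx.distributed.all_gather`. A sketch of the local calculation, assuming the clamp on elapsed time is meant to keep a near-instant warmup from dividing by zero:

    import math

    def cancel_check_interval(tokens_generated: int, elapsed_s: float) -> int:
        """Tokens to generate between polls of the cancel channel."""
        tokens_per_s = tokens_generated / max(elapsed_s, 1e-3)  # floor elapsed
        return min(math.ceil(tokens_per_s), 100)

    # A 120 tok/s warmup polls every 100 tokens; a 30 tok/s one every 30.
    assert cancel_check_interval(120, 1.0) == 100
    assert cancel_check_interval(30, 1.0) == 30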

View File

@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.runner import parse_deepseek_v32
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────

View File

@@ -6,7 +6,7 @@ from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
import exo.worker.runner.llm_inference.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -180,7 +180,7 @@ def _run(tasks: Iterable[Task]):
task_receiver.close = nothin
task_receiver.join = nothin
with unittest.mock.patch(
"exo.worker.runner.runner.mx.distributed.all_gather",
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
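
The test updates are mechanical but worth understanding: `unittest.mock.patch` takes a string that is resolved at runtime against the module where the name is looked up, so moving `runner.py` breaks every old `exo.worker.runner.runner` target (the patch can no longer import that module) even though the patched function itself is unchanged. A stdlib-only illustration of string-based patch targets:

    import json
    import unittest.mock

    # The target string names the attribute to replace on the json module.
    with unittest.mock.patch("json.dumps", lambda obj, **kw: "<mocked>"):
        assert json.dumps({"a": 1}) == "<mocked>"
    assert json.dumps({"a": 1}) == '{"a": 1}'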

View File

@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.runner import parse_gpt_oss
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.

View File

@@ -4,8 +4,8 @@ from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
from exo.worker.runner.llm_inference.runner import parse_tool_calls
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
def _make_responses(