Compare commits

..

5 Commits

Author SHA1 Message Date
Evan
7c888e6612 api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-01-30 13:56:49 +00:00
Evan Quiney
46a14153dd switch to ModelCard.load outside of download log (#1339)
some attempts to load model cards (i.e. build_base_shard) always went
through networking rather than using downloaded model cards. we should
always default to ModelCard.load in these scenarios
2026-01-30 11:20:20 +00:00
Evan
9ba61f3733 improve log message in shard downloader
closes #1336
2026-01-30 10:35:01 +00:00
rltakashige
d9eca75895 Add usage stats (#1333)
## Motivation

(Probably) the final missing piece of the Chat Completions API 

## Changes

Add UsageStats 

## Why It Works

OpenCode reviewed my PR and gave me stats:

<img width="1150" height="802" alt="image"
src="https://github.com/user-attachments/assets/ebc06bae-797f-4087-87d5-2f26cf60fc48"
/>


## Test Plan

### Automated Testing
No tests were broken.
2026-01-30 10:23:08 +00:00
rltakashige
9dabde7e57 Fix bench after recent updates (#1331)
## Motivation

A lot of changes happened without much attention to the state of exo
bench.

## Changes

Use TaggedModel for BenchChatCompletion so it serialises properly.
Don't break after gpt oss tool call to preserve parity with the rest of
the codebase.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<img width="2856" height="678" alt="image"
src="https://github.com/user-attachments/assets/2e18cf0d-c0f8-467c-9763-1a6a59c8a327"
/>

Also tested GPT OSS tool calling in OpenCode
2026-01-29 19:14:40 +00:00
30 changed files with 571 additions and 907 deletions

View File

@@ -5,18 +5,18 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)

View File

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
logger.warning(f"Error downloading shard: {type(e).__name__}")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -1,7 +1,6 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
@@ -66,7 +65,9 @@ from exo.shared.types.api import (
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
StreamOptions,
ToolCall,
Usage,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -89,6 +90,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -113,17 +115,10 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk,
command_id: CommandId,
usage: Usage | None,
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -148,21 +143,10 @@ def chunk_to_response(
finish_reason=chunk.finish_reason,
)
],
usage=usage,
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -284,7 +268,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -301,7 +285,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
model_card = await ModelCard.load(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -329,7 +313,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_card = await ModelCard.load(model_id)
try:
placements = get_instance_placements(
@@ -518,23 +502,22 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, stream_options: StreamOptions | None = None
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
include_usage = stream_options.include_usage if stream_options else False
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
@@ -550,8 +533,10 @@ class API:
yield "data: [DONE]\n\n"
return
usage = chunk.usage if include_usage else None
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
chunk, command_id, usage=usage
)
logger.debug(f"chunk_response: {chunk_response}")
@@ -562,13 +547,14 @@ class API:
async def _collect_chat_completion(
self, command_id: CommandId
) -> ChatCompletionResponse:
) -> AsyncGenerator[ChatCompletionResponse]:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
@@ -593,13 +579,16 @@ class API:
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.usage is not None:
usage = chunk.usage
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
yield ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
@@ -614,14 +603,16 @@ class API:
finish_reason=finish_reason,
)
],
usage=usage,
)
return
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -674,7 +665,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -683,7 +674,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -701,16 +692,19 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, payload.stream_options),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return StreamingResponse(
self._collect_chat_completion(command.command_id),
media_type="application/json",
)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -730,12 +724,12 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await ModelCard.load(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -781,10 +775,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
payload.model = await self._validate_image_model(ModelId(payload.model))
command = ImageGeneration(
request_params=payload,
@@ -914,6 +905,11 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -995,6 +991,11 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1029,13 +1030,10 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
request_params=payload,
@@ -1053,7 +1051,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: str,
model: ModelId,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1067,7 +1065,6 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1149,7 +1146,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1205,7 +1202,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -246,7 +248,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case PlaceInstance():
@@ -258,7 +260,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -268,7 +270,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -278,6 +280,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -286,10 +300,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -180,6 +186,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -195,6 +202,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -94,35 +94,20 @@ def get_shard_assignments_for_pipeline_parallel(
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Determine CFG parallelism topology
# CFG parallel only for even node counts with CFG models (2+ nodes)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
cfg_world_size = 2 if use_cfg_parallel else 1
pipeline_world_size = world_size // cfg_world_size
# For CFG parallel, we only need to allocate layers for one pipeline group
# (both CFG groups run the same layers). Use the first pipeline group's nodes.
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = sum(
(node_memory[node_id].ram_available for node_id in pipeline_node_ids),
start=Memory(),
)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
for node_id in pipeline_node_ids
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each pipeline node has sufficient memory for its assigned layers
# Use integer arithmetic to avoid floating point precision issues
total_storage_bytes = model_card.storage_size.in_bytes
for i, node_id in enumerate(pipeline_node_ids):
node_layers = layer_allocations[i]
# Integer division then multiply to get conservative estimate
required_memory = (total_storage_bytes * node_layers) // total_layers
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -131,69 +116,24 @@ def get_shard_assignments_for_pipeline_parallel(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
# CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
# CFG group 1: pipeline ranks in descending order (reversed)
# This places both "last stages" as ring neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
cfg_pipeline_to_device: dict[tuple[int, int], int] = {
(cfg_rank, pipeline_rank): i
for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
}
for i, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
is_first_stage = pipeline_rank == 0
is_last_stage = pipeline_rank == pipeline_world_size - 1
if is_last_stage:
next_pipeline_device = None
else:
next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]
if is_first_stage:
prev_pipeline_device = None
else:
prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]
if is_last_stage and use_cfg_parallel:
other_cfg_rank = 1 - cfg_rank
cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
else:
cfg_peer_device = None
first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
last_pipeline_device = cfg_pipeline_to_device[
(cfg_rank, pipeline_world_size - 1)
]
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_before,
end_layer=layers_before + node_layers,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
explicit_pipeline_rank=pipeline_rank,
next_pipeline_device=next_pipeline_device,
prev_pipeline_device=prev_pipeline_device,
cfg_peer_device=cfg_peer_device,
first_pipeline_device=first_pipeline_device,
last_pipeline_device=last_pipeline_device,
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,

View File

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1

View File

@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
@@ -21,7 +20,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@@ -488,195 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
class TestCfgParallelPlacement:
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for i, node_id in enumerate(node_ids):
next_node = node_ids[(i + 1) % len(node_ids)]
conn = Connection(
source=node_id,
sink=next_node,
edge=create_socket_connection(i + 1),
)
topology.add_connection(conn)
return topology
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# Both nodes should have all layers (no pipeline split)
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.start_layer == 0
assert shard.end_layer == 60
assert shard.cfg_world_size == 2
cfg_ranks = sorted(
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
)
assert cfg_ranks == [0, 1]
def test_four_nodes_cfg_model_uses_hybrid(self):
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
nodes = [NodeId() for _ in range(4)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 4]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 4
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.cfg_world_size == 2
assert shard.pipeline_world_size == 2
# Check we have 2 nodes in each CFG group
cfg_0_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
]
cfg_1_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
]
assert len(cfg_0_shards) == 2
assert len(cfg_1_shards) == 2
# Both CFG groups should have the same layer assignments
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
"""Three nodes (odd) with CFG model should use sequential CFG."""
nodes = [NodeId() for _ in range(3)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 3]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 3
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means sequential CFG
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
def test_two_nodes_non_cfg_model_uses_pipeline(self):
"""Two nodes with non-CFG model should use pure pipeline."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("flux-test"),
n_layers=57,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=False, # Non-CFG model
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means no CFG parallel
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
# Should have actual layer sharding (pipeline)
layer_ranges = sorted(
(s.start_layer, s.end_layer)
for s in shards
if isinstance(s, PipelineShardMetadata)
)
# First shard starts at 0, last shard ends at 57
assert layer_ranges[0][0] == 0
assert layer_ranges[-1][1] == 57

View File

@@ -47,7 +47,6 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
uses_cfg: bool = False
@field_validator("tasks", mode="before")
@classmethod
@@ -563,7 +562,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -598,7 +596,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -684,7 +681,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(transformer_bytes),
)
}
@@ -704,7 +700,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(quant_transformer_bytes),
)

View File

@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -116,8 +116,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
class StreamingChoiceResponse(BaseModel):
@@ -170,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
class StreamOptions(BaseModel):
include_usage: bool = False
class ChatCompletionTaskParams(TaggedModel):
model_config = ConfigDict(extra="ignore")
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
@@ -184,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None

View File

@@ -2,6 +2,7 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
class ImageGeneration(BaseCommand):
@@ -48,6 +49,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -84,6 +89,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -24,6 +25,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -54,12 +56,16 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -6,6 +6,7 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -24,6 +25,7 @@ class GenerationResponse(BaseRunnerResponse):
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
class ImageGenerationResponse(BaseRunnerResponse):
@@ -57,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -57,62 +57,8 @@ class PipelineShardMetadata(BaseShardMetadata):
Layers are represented as a half-open interval [start_layer, end_layer),
where start_layer is inclusive and end_layer is exclusive.
CFG parallelism fields:
- cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
- cfg_world_size: 1 = sequential CFG, 2 = parallel CFG
Communication rank fields (explicit to support ring topology):
- next_pipeline_device: device to send to in pipeline forward pass
- prev_pipeline_device: device to receive from in pipeline forward pass
- cfg_peer_device: device for CFG exchange (last stage only)
- first_pipeline_device: device of first stage in same CFG group (for latent return)
"""
cfg_rank: int = 0
cfg_world_size: int = 1
# Explicit pipeline position (CFG group 1 uses reversed pipeline order)
explicit_pipeline_rank: int | None = None
next_pipeline_device: int | None = None
prev_pipeline_device: int | None = None
cfg_peer_device: int | None = None
first_pipeline_device: int | None = None
last_pipeline_device: int | None = None
@property
def pipeline_world_size(self) -> int:
return self.world_size // self.cfg_world_size
@property
def pipeline_rank(self) -> int:
if self.explicit_pipeline_rank is not None:
return self.explicit_pipeline_rank
return self.device_rank % self.pipeline_world_size
@property
def is_pipeline_first(self) -> bool:
return self.pipeline_rank == 0
@property
def is_pipeline_last(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,
self.device_rank,
self.world_size,
self.cfg_rank,
self.cfg_world_size,
)
)
class TensorShardMetadata(BaseShardMetadata):
pass

View File

@@ -37,12 +37,7 @@ class DistributedImageModel:
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
has_layer_sharding = (
shard_metadata.start_layer != 0
or shard_metadata.end_layer != shard_metadata.n_layers
)
if group is not None and has_layer_sharding:
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,

View File

@@ -86,27 +86,6 @@ class PromptData(ABC):
"""
...
@abstractmethod
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Get embeddings for a single CFG branch (positive or negative).
Used for sequential CFG and CFG parallel modes where we process
one branch at a time instead of batching.
Args:
positive: True for positive prompt, False for negative prompt
Returns:
Tuple of:
- embeds: [1, seq, hidden] prompt embeddings
- mask: [1, seq] attention mask or None
- pooled: [1, hidden] pooled embeddings or None
- conditioning_latents: [1, latent_seq, latent_dim] or None
"""
...
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
_config: ImageModelConfig

View File

@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
return None
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Flux doesn't use CFG, but we return positive data for compatibility."""
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
def __init__(

View File

@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
return batched_embeds, batched_mask, None, cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self.conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self.conditioning_latents,
)
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
"""Adapter for Qwen-Image model.

View File

@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
return batched_embeds, batched_mask, None, batched_cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self._conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self._conditioning_latents,
)
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
"""Adapter for Qwen-Image-Edit model.

View File

@@ -1,7 +1,5 @@
from collections.abc import Iterator
from dataclasses import dataclass
from math import ceil
from typing import Any, Optional, final
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
@@ -22,16 +20,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
)
@final
@dataclass
class CfgBranch:
positive: bool
embeds: mx.array
mask: mx.array | None
pooled: mx.array | None
cond_latents: mx.array | None
def calculate_patch_heights(
latent_height: int, num_patches: int
) -> tuple[list[int], int]:
@@ -84,11 +72,22 @@ class DiffusionRunner:
self.adapter = adapter
self.group = group
self._init_cfg_topology(shard_metadata)
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = (
num_patches if num_patches else max(1, self.pipeline_world_size)
)
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
@@ -98,48 +97,6 @@ class DiffusionRunner:
self._compute_assigned_blocks()
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
"""Initialize CFG and pipeline topology from shard metadata."""
if self.group is None:
self.rank = 0
self.world_size = 1
self.start_layer = 0
self.end_layer = self.config.total_blocks
self.cfg_rank = 0
self.cfg_world_size = 1
self.cfg_parallel = False
self.pipeline_world_size = 1
self.pipeline_rank = 0
self.next_pipeline_rank: int | None = None
self.prev_pipeline_rank: int | None = None
self.cfg_peer_rank: int | None = None
self.first_pipeline_rank: int = 0
self.last_pipeline_rank: int = 0
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.cfg_rank = shard_metadata.cfg_rank
self.cfg_world_size = shard_metadata.cfg_world_size
self.cfg_parallel = self.cfg_world_size > 1
self.pipeline_world_size = shard_metadata.pipeline_world_size
self.pipeline_rank = shard_metadata.pipeline_rank
self.next_pipeline_rank = shard_metadata.next_pipeline_device
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
self.cfg_peer_rank = shard_metadata.cfg_peer_device
assert shard_metadata.first_pipeline_device is not None
assert shard_metadata.last_pipeline_device is not None
self.first_pipeline_rank = shard_metadata.first_pipeline_device
self.last_pipeline_rank = shard_metadata.last_pipeline_device
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
@@ -176,11 +133,11 @@ class DiffusionRunner:
@property
def is_first_stage(self) -> bool:
return self.pipeline_rank == 0
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
@@ -191,97 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
"""Yield the CFG branches this node should process.
- No CFG: yields one branch (positive)
- CFG parallel: yields one branch (our assigned branch)
- Sequential CFG: yields two branches (positive, then negative)
"""
if not self.adapter.needs_cfg:
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
yield CfgBranch(
positive=True,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
elif self.cfg_parallel:
positive = self.cfg_rank == 0
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
yield CfgBranch(
positive=positive,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
else:
pos_embeds, pos_mask, pos_pooled, pos_cond = (
prompt_data.get_cfg_branch_data(positive=True)
)
yield CfgBranch(
positive=True,
embeds=pos_embeds,
mask=pos_mask,
pooled=pos_pooled,
cond_latents=pos_cond,
)
neg_embeds, neg_mask, neg_pooled, neg_cond = (
prompt_data.get_cfg_branch_data(positive=False)
)
yield CfgBranch(
positive=False,
embeds=neg_embeds,
mask=neg_mask,
pooled=neg_pooled,
cond_latents=neg_cond,
)
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
if len(results) == 1:
positive, noise = results[0]
if self.cfg_parallel and self.is_last_stage:
# TODO(ciaran): try to remove
mx.eval(noise)
return self._exchange_and_apply_guidance(noise, positive)
return noise
noise_neg = next(n for p, n in results if not p)
noise_pos = next(n for p, n in results if p)
return self._apply_guidance(noise_pos, noise_neg)
def _exchange_and_apply_guidance(
self, noise: mx.array, is_positive: bool
) -> mx.array:
assert self.group is not None
assert self.cfg_peer_rank is not None
if is_positive:
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_neg)
noise_pos = noise
else:
noise_pos = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_pos)
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = noise
return self._apply_guidance(noise_pos, noise_neg)
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
scale = self._get_effective_guidance_scale()
assert scale is not None
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -598,9 +464,7 @@ class DiffusionRunner:
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif (
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
):
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
@@ -624,29 +488,42 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
# Reset caches before each branch to ensure no state contamination
self._reset_all_caches()
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
noise = self._forward_pass(
latents,
branch.embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
)
results.append((branch.positive, noise))
noise = self._combine_cfg_results(results)
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
def _create_patches(
@@ -697,7 +574,7 @@ class DiffusionRunner:
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
@@ -709,17 +586,16 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
@@ -744,30 +620,27 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
assert self.next_pipeline_rank is not None
concatenated = mx.distributed.send(
concatenated, self.next_pipeline_rank, group=self.group
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
@@ -782,9 +655,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
@@ -807,65 +679,75 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
cond_latents = branch.cond_latents
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
branch.embeds,
pooled_embeds,
branch.mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise, timestep=t, latents=prev_latents
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(
hidden_states, self.first_pipeline_rank, group=self.group
)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.last_pipeline_rank, group=self.group
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
@@ -884,10 +766,39 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
prev_patch_latents = [p for p in patch_latents]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = [p for p in patch_latents]
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
@@ -899,52 +810,31 @@ class DiffusionRunner:
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.last_pipeline_rank, group=self.group
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
results: list[tuple[bool, mx.array]] = []
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
for branch in self._get_cfg_branches(prompt_data):
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(branch.mask)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
branch.embeds,
config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=branch.embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise,
@@ -954,9 +844,7 @@ class DiffusionRunner:
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx],
self.first_pipeline_rank,
group=self.group,
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
@@ -996,12 +884,11 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1010,7 +897,7 @@ class DiffusionRunner:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
@@ -1038,34 +925,29 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
assert self.next_pipeline_rank is not None
patch_concat = mx.distributed.send(
patch_concat, self.next_pipeline_rank, group=self.group
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1080,10 +962,7 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None

View File

@@ -10,8 +10,11 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
@@ -24,7 +27,6 @@ from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_c
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -39,7 +41,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
) -> tuple[float, int]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -50,7 +52,7 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
return 0.0, 0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
@@ -85,7 +87,7 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
return tokens_per_sec, num_tokens
def warmup_inference(
@@ -133,10 +135,6 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -169,6 +167,8 @@ def mlx_generate(
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
logger.info(f"{is_bench=}")
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
@@ -204,7 +204,9 @@ def mlx_generate(
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
@@ -212,28 +214,43 @@ def mlx_generate(
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)
if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
@@ -245,11 +262,24 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)
if out.finish_reason is not None:
@@ -274,5 +304,3 @@ def mlx_generate(
else:
kv_prefix_cache.add_kv_cache(full_prompt, caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -67,8 +67,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -86,30 +84,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -538,3 +512,23 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -115,8 +116,9 @@ class Worker:
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async with create_task_group() as tg:
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -220,15 +222,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -351,8 +360,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -59,7 +60,8 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _cancel_tasks(runners, tasks)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
@@ -270,7 +272,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +294,31 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id, cancelled_task_id=task.task_id
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,5 +1,6 @@
import base64
import json
import math
import time
from collections.abc import Generator
from functools import cache
@@ -37,6 +38,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -61,7 +63,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
@@ -78,6 +80,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -86,6 +89,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -100,11 +104,15 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -113,6 +121,7 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -157,7 +166,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -169,7 +178,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -177,8 +186,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -189,15 +196,30 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
t = time.perf_counter()
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / (time.perf_counter() - t)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
@@ -205,8 +227,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -225,8 +247,9 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert check_for_cancel_every
assert task_params.messages[0].content is not None
try:
@@ -237,7 +260,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -261,11 +284,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -277,9 +300,22 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
completion_tokens += 1
if (
device_rank == 0
and response.finish_reason == "error"
@@ -307,6 +343,7 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -320,6 +357,7 @@ def main(
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
),
)
)
@@ -344,7 +382,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -358,11 +396,12 @@ def main(
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -388,11 +427,7 @@ def main(
image_index += 1
# can we make this more explicit?
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -410,7 +445,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -422,11 +457,12 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -451,11 +487,7 @@ def main(
)
image_index += 1
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -490,7 +522,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
@@ -545,10 +577,10 @@ def parse_gpt_oss(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
],
usage=response.usage,
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments
@@ -694,7 +726,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
except (
json.JSONDecodeError,

View File

@@ -49,10 +49,12 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -63,8 +65,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -72,6 +74,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -86,6 +89,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -93,44 +97,47 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
tg.start_soon(self._forward_events)
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
def shutdown(self):
assert self._tg
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
if task.task_id in self.completed or task.task_id in self.pending:
logger.info(f"Skipping invalid task {task} as it has already been queued")
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -140,7 +147,13 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
async def _forward_events(self):
with self._ev_recv as events:
@@ -206,4 +219,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
self.shutdown()
await self.shutdown()

View File

@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
@@ -182,6 +182,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)

View File

@@ -0,0 +1,18 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}