Compare commits

..

6 Commits

Author SHA1 Message Date
ciaranbor
77fbffcebe Ensure unique model Id for each quant 2026-01-23 20:08:45 +00:00
ciaranbor
4feb3cde86 Enable image model quantization 2026-01-23 19:58:50 +00:00
ciaranbor
23fd37fe4d Add FLUX.1-Krea-dev model (#1269)
## Why It Works

Same implementation as FLUX.1-dev, just different weights
2026-01-23 19:48:24 +00:00
Alex Cheema
d229df38f9 Fix placement filter to use subset matching instead of exact match (#1265)
## Motivation

When using the dashboard's instance placement filter (clicking nodes in
the topology), it was filtering to placements that use exactly the
selected nodes. This isn't the expected behavior - users want to see
placements that include all selected nodes, but may also include
additional nodes.

For example, selecting nodes [A, B] should show placements using [A, B],
[A, B, C], [A, B, C, D], etc. - not just [A, B].

## Changes

- Added `required_nodes` parameter to `place_instance()` in
`placement.py`
- Filter cycles early in placement to only those containing all required
nodes (subset matching)
- Simplified `api.py` by removing the subgraph topology filtering and
passing `required_nodes` directly to placement
- Renamed internal `node_ids` variable to `placement_node_ids` to avoid
shadowing the parameter

## Why It Works

By filtering cycles at the placement level using
`required_nodes.issubset(cycle.node_ids)`, we ensure that only cycles
containing all the user-selected nodes are considered. This happens
early in the placement algorithm, so we don't waste time computing
placements that would be filtered out later.

## Test Plan

### Manual Testing
- Select nodes in the dashboard topology view
- Verify that placements shown include all selected nodes (but may
include additional nodes)
- Verify that placements not containing the selected nodes are filtered
out

### Automated Testing
- Existing placement tests pass
- `uv run pytest src/exo/master/tests/ -v` - 37 tests pass

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 19:40:31 +00:00
Alex Cheema
8a595fee2f Fix Thunderbolt bridge cycle detection to include 2-node cycles (#1261)
## Motivation

Packet storms occur with Thunderbolt bridge enabled on 2 machines
connected by Thunderbolt, not just 3+ node cycles as previously assumed.
The cycle detection was too conservative and missed this case.

## Changes

- Changed the minimum cycle length from >2 (3+ nodes) to >=2 (2+ nodes)
- Updated the early return threshold from `< 3` to `< 2` enabled nodes
- Updated docstring to reflect the new behavior

## Why It Works

A Thunderbolt bridge loop between just 2 machines can still create
broadcast storms when both have the bridge enabled. The previous
threshold of 3+ was based on an incorrect assumption that 2-node
connections wouldn't cause this problem.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
- Tested with 2 machines connected via Thunderbolt with bridge enabled
- Confirmed packet storms occur in this configuration
- Verified the fix correctly detects and handles 2-node cycles

### Automated Testing
- Existing topology tests cover cycle detection logic

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 19:34:48 +00:00
ciaranbor
c8571a17a3 Fix guidance (#1264)
## Motivation

Previously, we only handled user-provided guidance parameter for CFG
models.

## Changes

Just pass the parameter to model setup
2026-01-23 19:13:45 +00:00
20 changed files with 293 additions and 206 deletions

View File

@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.

View File

@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"mflux==0.15.4",
"python-multipart>=0.0.21",
]

View File

@@ -32,6 +32,7 @@ from exo.download.huggingface_utils import (
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
@@ -481,6 +482,11 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
return ["*"]
def is_image_model(shard: ShardMetadata) -> bool:
tasks = shard.model_card.tasks
return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path):
@@ -522,6 +528,15 @@ async def download_shard(
file_list, allow_patterns=allow_patterns, key=lambda x: x.path
)
)
# For image models, skip root-level safetensors files since weights
# are stored in component subdirectories (e.g., transformer/, vae/)
if is_image_model(shard):
filtered_file_list = [
f
for f in filtered_file_list
if "/" in f.path or not f.path.endswith(".safetensors")
]
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(

View File

@@ -75,7 +75,6 @@ from exo.shared.types.chunks import (
ToolCallChunk,
)
from exo.shared.types.commands import (
CancelTask,
ChatCompletion,
Command,
CreateInstance,
@@ -357,14 +356,9 @@ class API:
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -377,7 +371,9 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(1, len(list(topology.list_nodes())) + 1)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
]
)
# TODO: PDD
@@ -395,8 +391,9 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=topology,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
@@ -435,14 +432,16 @@ class API:
instance = new_instances[0]
shard_assignments = instance.shard_assignments
node_ids = list(shard_assignments.node_to_runner.keys())
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if node_ids:
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
@@ -450,7 +449,7 @@ class API:
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
@@ -462,7 +461,14 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
@@ -502,10 +508,12 @@ class API:
break
except anyio.get_cancelled_exc_class():
command = CancelTask(cancelled_command_id=command_id)
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
command = TaskFinished(finished_command_id=command_id)
@@ -893,10 +901,6 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = CancelTask(cancelled_command_id=command_id)
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -978,10 +982,6 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = CancelTask(cancelled_command_id=command_id)
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -12,7 +12,6 @@ from exo.master.placement import (
)
from exo.shared.apply import apply
from exo.shared.types.commands import (
CancelTask,
ChatCompletion,
CreateInstance,
DeleteInstance,
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -280,15 +278,6 @@ class Master:
chunk=chunk,
)
)
case CancelTask():
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=self.command_task_mapping[
command.cancelled_command_id
],
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -297,7 +286,10 @@ class Master:
]
)
)
del self.command_task_mapping[command.finished_command_id]
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -54,9 +54,18 @@ def place_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
candidate_cycles = [
cycle
for cycle in candidate_cycles
if required_nodes.issubset(cycle.node_ids)
]
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
)

View File

@@ -40,6 +40,7 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
quantization: int | None = None
@field_validator("tasks", mode="before")
@classmethod
@@ -413,7 +414,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -428,7 +429,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -458,7 +459,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
storage_size=Memory.from_bytes(23802816640 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
@@ -470,7 +471,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +485,49 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -501,7 +544,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -509,10 +552,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -535,7 +578,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -543,10 +586,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -568,6 +611,93 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _create_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
quantizations = [8, 6, 5, 4, 3]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
quantization=None,
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
model_id = base_card.model_id + f"-{quant}bit"
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=ModelId(model_id),
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
quantization=quant,
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)

View File

@@ -248,8 +248,8 @@ class Topology:
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
Only returns cycles with >=2 nodes (2+ machines in a loop), as
1 node doesn't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
@@ -257,7 +257,7 @@ class Topology:
if status.enabled
}
if len(enabled_nodes) < 3:
if len(enabled_nodes) < 2:
return []
thunderbolt_ips = _get_ips_with_interface_type(
@@ -288,7 +288,7 @@ class Topology:
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) > 2
if len(cycle) >= 2
]

View File

@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class CancelTask(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -88,7 +84,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| CancelTask
| TaskFinished
| SendInputChunk
)

View File

@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -61,10 +60,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -92,7 +87,6 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -71,8 +71,10 @@ class DistributedImageModel:
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_card.model_id
model_card = bound_instance.bound_shard.model_card
model_id = model_card.model_id
model_path = build_model_path(model_id)
quantize = model_card.quantization
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -93,6 +95,7 @@ class DistributedImageModel:
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
quantize=quantize,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
@@ -140,6 +143,7 @@ class DistributedImageModel:
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)

View File

@@ -33,6 +33,7 @@ _ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,

View File

@@ -23,6 +23,7 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -89,6 +90,10 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -181,3 +186,5 @@ def mlx_generate(
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -70,6 +70,8 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -87,6 +89,30 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -510,33 +536,3 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,7 +33,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -230,10 +229,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -356,6 +351,8 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -2,11 +2,8 @@
from collections.abc import Mapping, Sequence
from loguru import logger
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -51,8 +48,8 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] = {},
input_chunk_counts: Mapping[CommandId, int] = {},
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
@@ -63,7 +60,6 @@ def plan(
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _cancel_tasks(runners, tasks)
)
@@ -274,7 +270,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -288,7 +284,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len(input_chunk_buffer.get(cmd_id, {}))
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -296,33 +292,16 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
logger.info(f"{task.task_id} is cancelled!")
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
logger.info(f"{task.task_id} is mine!")
if task.task_id in runner.cancelled:
logger.info(f"{task.task_id} already cancelled")
continue
logger.info(f"cancelling {task.task_id}")
CancelTask(instance_id=task.instance_id, cancelled_task_id=task.task_id)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,7 +15,6 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -39,7 +38,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,5 +1,4 @@
import base64
import contextlib
import json
import time
from collections.abc import Generator
@@ -38,7 +37,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -64,7 +62,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, WouldBlock
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
generate_image,
@@ -79,7 +77,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -88,7 +85,6 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -103,11 +99,8 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
@@ -118,10 +111,6 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
with contextlib.suppress(WouldBlock):
cancelled_tasks.add(cancel_receiver.receive_nowait())
if task.task_id in cancelled_tasks:
continue
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -166,7 +155,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -176,7 +165,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -185,6 +174,8 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -195,11 +186,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert not isinstance(model, DistributedImageModel)
assert tokenizer
toks = warmup_inference(
model=inference_model,
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -211,8 +202,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -231,7 +222,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert inference_model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
@@ -243,14 +234,14 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(inference_model, GptOssModel):
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
@@ -281,12 +272,6 @@ def main(
)
for response in mlx_generator:
with contextlib.suppress(WouldBlock):
cancelled_tasks.add(cancel_receiver.receive_nowait())
want_to_cancel = task.task_id in cancelled_tasks
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -353,7 +338,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -367,9 +352,7 @@ def main(
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -416,7 +399,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -428,9 +411,7 @@ def main(
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
@@ -493,7 +474,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -49,12 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -65,8 +63,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -74,7 +72,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +86,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -135,7 +131,6 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -145,13 +140,7 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:

27
uv.lock generated
View File

@@ -412,7 +412,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
@@ -458,16 +458,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -997,7 +987,7 @@ wheels = [
[[package]]
name = "mflux"
version = "0.15.3"
version = "0.15.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1023,9 +1013,9 @@ dependencies = [
{ name = "twine", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/23/c5/dd12e16714702255d89b7ccc6f217c405a9fdcf2af950a2236892c50a219/mflux-0.15.3.tar.gz", hash = "sha256:e32ea66a81aad4f77eea2415b17c27fc3d9ce662a842565c62871ff570f4ef2f", size = 740701, upload-time = "2026-01-19T22:54:59.066Z" }
sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/95322db7a865e4df6bad108b1c99aa7fbe211aac3f298f3ad696c2744a39/mflux-0.15.4.tar.gz", hash = "sha256:138e1aedae86e13eafeb8faec017945fcdcca42c3234daabcd81a83c9a202ace", size = 741228, upload-time = "2026-01-20T15:39:26.807Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cf/9f/a673ee12877a0943a4059c51b5beb6cf909c92f25384365cf8beeb475159/mflux-0.15.3-py3-none-any.whl", hash = "sha256:631cfcc038f27e9bd0ff76c25c2bc7373562b8f64cf0ce961fc268a246fa699e", size = 987270, upload-time = "2026-01-19T22:54:57.155Z" },
{ url = "https://files.pythonhosted.org/packages/8e/be/81cf4ce2d1933b9b210c028a05ac95e958008c0d43e377a5f2757b7f2d4d/mflux-0.15.4-py3-none-any.whl", hash = "sha256:f04d9b1d7c5cd67880f483ab29fb2097648a25459eef9c5ee6480fad46de5e82", size = 987644, upload-time = "2026-01-20T15:39:24.817Z" },
]
[[package]]
@@ -2227,6 +2217,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"