Compare commits

...

2 Commits

Author SHA1 Message Date
Evan
525aaab808 fix 2026-02-26 13:42:07 +00:00
ciaranbor
eaed92952c Use tmpdir for coordination file (#1624)
## Motivation

Coordination files for MLX distributed init were written to the current
working directory (`./hosts_*.json`).

## Changes

- Move coordination file creation into a `tempfile.TemporaryDirectory()`,
which is cleaned up automatically on context-manager exit.
2026-02-26 10:59:36 +00:00
3 changed files with 42 additions and 19 deletions

View File

@@ -524,15 +524,15 @@ class API:
if ( if (
model_card.model_id, model_card.model_id,
sharding, instance.sharding(),
instance_meta, instance.instance_meta(),
len(placement_node_ids), len(placement_node_ids),
) not in seen: ) not in seen:
previews.append( previews.append(
PlacementPreview( PlacementPreview(
model_id=model_card.model_id, model_id=model_card.model_id,
sharding=sharding, sharding=instance.sharding(),
instance_meta=instance_meta, instance_meta=instance.instance_meta(),
instance=instance, instance=instance,
memory_delta_by_node=memory_delta_by_node or None, memory_delta_by_node=memory_delta_by_node or None,
error=None, error=None,
@@ -541,8 +541,8 @@ class API:
seen.add( seen.add(
( (
model_card.model_id, model_card.model_id,
sharding, instance.sharding(),
instance_meta, instance.instance_meta(),
len(placement_node_ids), len(placement_node_ids),
) )
) )

View File

@@ -4,7 +4,13 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None: def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None) return self.shard_assignments.runner_to_shard.get(runner_id, None)
@staticmethod
def instance_meta() -> InstanceMeta: ...
def sharding(self) -> Sharding:
if all(
isinstance(sm, PipelineShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Pipeline
if all(
isinstance(sm, TensorShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Tensor
raise ValueError("shard metadata malformed")
class MlxRingInstance(BaseInstance): class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]] hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxRing
class MlxJacclInstance(BaseInstance): class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]] jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str] jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxJaccl
# TODO: Single node instance # TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance Instance = MlxRingInstance | MlxJacclInstance

View File

@@ -2,6 +2,7 @@ import json
import os import os
import re import re
import sys import sys
import tempfile
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, cast from typing import Any, cast
@@ -98,14 +99,13 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}") logger.info(f"Starting initialization for rank {rank}")
coordination_file = None with tempfile.TemporaryDirectory() as tmpdir:
try: coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
# TODO: singleton instances # TODO: singleton instances
match bound_instance.instance: match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_): case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id] hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json() hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,9 +128,6 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
) )
# Use RDMA connectivity matrix # Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices) jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f: with open(coordination_file, "w") as f:
@@ -150,10 +147,6 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete") logger.info(f"Rank {rank} mlx distributed initialization complete")
return group return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx( def initialize_mlx(