Compare commits

..

1 Commits

Author SHA1 Message Date
ciaranbor
6f003759c4 Use tmpdir for coordination file 2026-02-25 18:57:57 +00:00
3 changed files with 12 additions and 49 deletions

View File

@@ -524,15 +524,15 @@ class API:
if (
model_card.model_id,
instance.sharding(),
instance.instance_meta(),
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=instance.sharding(),
instance_meta=instance.instance_meta(),
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -541,8 +541,8 @@ class API:
seen.add(
(
model_card.model_id,
instance.sharding(),
instance.instance_meta(),
sharding,
instance_meta,
len(placement_node_ids),
)
)

View File

@@ -4,13 +4,7 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -30,40 +24,16 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)
@staticmethod
def instance_meta() -> InstanceMeta: ...
def sharding(self) -> Sharding:
if all(
isinstance(sm, PipelineShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Pipeline
if all(
isinstance(sm, TensorShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Tensor
raise ValueError("shard metadata malformed")
class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxRing
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxJaccl
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance

View File

@@ -2,6 +2,7 @@ import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, cast
@@ -98,14 +99,13 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
coordination_file = None
try:
with tempfile.TemporaryDirectory() as tmpdir:
coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,9 +128,6 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
@@ -150,10 +147,6 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(