mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-26 03:06:05 -05:00
Compare commits
1 Commits
consistent
...
ciaran/use
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f003759c4 |
@@ -524,15 +524,15 @@ class API:
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=instance.sharding(),
|
||||
instance_meta=instance.instance_meta(),
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -541,8 +541,8 @@ class API:
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,13 +4,7 @@ from pydantic import model_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -30,40 +24,16 @@ class BaseInstance(TaggedModel):
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta: ...
|
||||
|
||||
def sharding(self) -> Sharding:
|
||||
if all(
|
||||
isinstance(sm, PipelineShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Pipeline
|
||||
if all(
|
||||
isinstance(sm, TensorShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Tensor
|
||||
raise ValueError("shard metadata malformed")
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxRing
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
jaccl_devices: list[list[str | None]]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxJaccl
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
Instance = MlxRingInstance | MlxJacclInstance
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
@@ -98,14 +99,13 @@ def mlx_distributed_init(
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
coordination_file = None
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
coordination_file = str(
|
||||
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
@@ -128,9 +128,6 @@ def mlx_distributed_init(
|
||||
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
|
||||
)
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||
|
||||
with open(coordination_file, "w") as f:
|
||||
@@ -150,10 +147,6 @@ def mlx_distributed_init(
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
|
||||
Reference in New Issue
Block a user