mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
nid-persis
...
consistent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8e5d3710f |
@@ -524,15 +524,15 @@ class API:
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
sharding=instance.sharding(),
|
||||
instance_meta=instance.instance_meta(),
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -541,8 +541,8 @@ class API:
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,7 +4,13 @@ from pydantic import model_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta: ...
|
||||
|
||||
def sharding(self) -> Sharding:
|
||||
if all(
|
||||
isinstance(sm, PipelineShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Pipeline
|
||||
if all(
|
||||
isinstance(sm, TensorShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Tensor
|
||||
raise ValueError("shard metadata malformed")
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxRing
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
jaccl_devices: list[list[str | None]]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxJaccl
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
Instance = MlxRingInstance | MlxJacclInstance
|
||||
|
||||
Reference in New Issue
Block a user