mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-26 03:06:05 -05:00
Compare commits
1 Commits
remove-pyt
...
consistent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8e5d3710f |
@@ -524,15 +524,15 @@ class API:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
model_card.model_id,
|
model_card.model_id,
|
||||||
sharding,
|
instance.sharding(),
|
||||||
instance_meta,
|
instance.instance_meta(),
|
||||||
len(placement_node_ids),
|
len(placement_node_ids),
|
||||||
) not in seen:
|
) not in seen:
|
||||||
previews.append(
|
previews.append(
|
||||||
PlacementPreview(
|
PlacementPreview(
|
||||||
model_id=model_card.model_id,
|
model_id=model_card.model_id,
|
||||||
sharding=sharding,
|
sharding=instance.sharding(),
|
||||||
instance_meta=instance_meta,
|
instance_meta=instance.instance_meta(),
|
||||||
instance=instance,
|
instance=instance,
|
||||||
memory_delta_by_node=memory_delta_by_node or None,
|
memory_delta_by_node=memory_delta_by_node or None,
|
||||||
error=None,
|
error=None,
|
||||||
@@ -541,8 +541,8 @@ class API:
|
|||||||
seen.add(
|
seen.add(
|
||||||
(
|
(
|
||||||
model_card.model_id,
|
model_card.model_id,
|
||||||
sharding,
|
instance.sharding(),
|
||||||
instance_meta,
|
instance.instance_meta(),
|
||||||
len(placement_node_ids),
|
len(placement_node_ids),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,13 @@ from pydantic import model_validator
|
|||||||
|
|
||||||
from exo.shared.models.model_cards import ModelTask
|
from exo.shared.models.model_cards import ModelTask
|
||||||
from exo.shared.types.common import Host, Id, NodeId
|
from exo.shared.types.common import Host, Id, NodeId
|
||||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||||
|
from exo.shared.types.worker.shards import (
|
||||||
|
PipelineShardMetadata,
|
||||||
|
Sharding,
|
||||||
|
ShardMetadata,
|
||||||
|
TensorShardMetadata,
|
||||||
|
)
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||||
|
|
||||||
|
|
||||||
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
|
|||||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def instance_meta() -> InstanceMeta: ...
|
||||||
|
|
||||||
|
def sharding(self) -> Sharding:
|
||||||
|
if all(
|
||||||
|
isinstance(sm, PipelineShardMetadata)
|
||||||
|
for sm in self.shard_assignments.runner_to_shard.values()
|
||||||
|
):
|
||||||
|
return Sharding.Pipeline
|
||||||
|
if all(
|
||||||
|
isinstance(sm, TensorShardMetadata)
|
||||||
|
for sm in self.shard_assignments.runner_to_shard.values()
|
||||||
|
):
|
||||||
|
return Sharding.Tensor
|
||||||
|
raise ValueError("shard metadata malformed")
|
||||||
|
|
||||||
|
|
||||||
class MlxRingInstance(BaseInstance):
|
class MlxRingInstance(BaseInstance):
|
||||||
hosts_by_node: dict[NodeId, list[Host]]
|
hosts_by_node: dict[NodeId, list[Host]]
|
||||||
ephemeral_port: int
|
ephemeral_port: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def instance_meta() -> InstanceMeta:
|
||||||
|
return InstanceMeta.MlxRing
|
||||||
|
|
||||||
|
|
||||||
class MlxJacclInstance(BaseInstance):
|
class MlxJacclInstance(BaseInstance):
|
||||||
jaccl_devices: list[list[str | None]]
|
jaccl_devices: list[list[str | None]]
|
||||||
jaccl_coordinators: dict[NodeId, str]
|
jaccl_coordinators: dict[NodeId, str]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def instance_meta() -> InstanceMeta:
|
||||||
|
return InstanceMeta.MlxJaccl
|
||||||
|
|
||||||
|
|
||||||
# TODO: Single node instance
|
# TODO: Single node instance
|
||||||
Instance = MlxRingInstance | MlxJacclInstance
|
Instance = MlxRingInstance | MlxJacclInstance
|
||||||
|
|||||||
Reference in New Issue
Block a user