Compare commits

...

1 Commits

Author SHA1 Message Date
Evan
525aaab808 fix 2026-02-26 13:42:07 +00:00
2 changed files with 37 additions and 7 deletions

View File

@@ -524,15 +524,15 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
sharding=instance.sharding(),
instance_meta=instance.instance_meta(),
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -541,8 +541,8 @@ class API:
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
instance.sharding(),
instance.instance_meta(),
len(placement_node_ids),
)
)

View File

@@ -4,7 +4,13 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)
@staticmethod
def instance_meta() -> InstanceMeta: ...
def sharding(self) -> Sharding:
if all(
isinstance(sm, PipelineShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Pipeline
if all(
isinstance(sm, TensorShardMetadata)
for sm in self.shard_assignments.runner_to_shard.values()
):
return Sharding.Tensor
raise ValueError("shard metadata malformed")
class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxRing
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
return InstanceMeta.MlxJaccl
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance