Compare commits

...

1 Commit

Author SHA1 Message Date
Evan
e8e5d3710f fix 2026-02-25 18:52:43 +00:00
2 changed files with 37 additions and 7 deletions

View File

@@ -524,15 +524,15 @@ class API:
if ( if (
model_card.model_id, model_card.model_id,
sharding, instance.sharding(),
instance_meta, instance.instance_meta(),
len(placement_node_ids), len(placement_node_ids),
) not in seen: ) not in seen:
previews.append( previews.append(
PlacementPreview( PlacementPreview(
model_id=model_card.model_id, model_id=model_card.model_id,
sharding=sharding, sharding=instance.sharding(),
instance_meta=instance_meta, instance_meta=instance.instance_meta(),
instance=instance, instance=instance,
memory_delta_by_node=memory_delta_by_node or None, memory_delta_by_node=memory_delta_by_node or None,
error=None, error=None,
@@ -541,8 +541,8 @@ class API:
seen.add( seen.add(
( (
model_card.model_id, model_card.model_id,
sharding, instance.sharding(),
instance_meta, instance.instance_meta(),
len(placement_node_ids), len(placement_node_ids),
) )
) )

View File

@@ -4,7 +4,13 @@ from pydantic import model_validator
from exo.shared.models.model_cards import ModelTask from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import Host, Id, NodeId from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
Sharding,
ShardMetadata,
TensorShardMetadata,
)
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
def shard(self, runner_id: RunnerId) -> ShardMetadata | None: def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None) return self.shard_assignments.runner_to_shard.get(runner_id, None)
@staticmethod
def instance_meta() -> InstanceMeta:
    """Return the ``InstanceMeta`` member identifying this instance kind.

    The base class has no tag of its own; concrete instance classes are
    expected to override this method.

    Raises:
        NotImplementedError: Always, when reached on a class that does not
            override it. A bare ``...`` body would return ``None`` here,
            which contradicts the declared return type and would let a
            missing override go unnoticed by callers.
    """
    raise NotImplementedError(
        "instance_meta() must be overridden by concrete instance classes"
    )
def sharding(self) -> Sharding:
    """Classify this instance's sharding strategy from its shard metadata.

    Returns:
        ``Sharding.Pipeline`` when every assigned shard is a
        ``PipelineShardMetadata`` (this is also the result when no shards
        are assigned, because ``all`` over an empty iterable is true);
        otherwise ``Sharding.Tensor`` when every assigned shard is a
        ``TensorShardMetadata``.

    Raises:
        ValueError: If the shards are neither all pipeline nor all tensor.
    """
    assigned = tuple(self.shard_assignments.runner_to_shard.values())
    # Order matters: the pipeline check runs first, matching the precedence
    # callers already observe for the empty case.
    candidates = (
        (PipelineShardMetadata, Sharding.Pipeline),
        (TensorShardMetadata, Sharding.Tensor),
    )
    for metadata_type, strategy in candidates:
        if all(isinstance(shard, metadata_type) for shard in assigned):
            return strategy
    raise ValueError("shard metadata malformed")
class MlxRingInstance(BaseInstance): class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]] hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int ephemeral_port: int
@staticmethod
def instance_meta() -> InstanceMeta:
    """Identify this instance kind as ``InstanceMeta.MlxRing``."""
    return InstanceMeta.MlxRing
class MlxJacclInstance(BaseInstance): class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]] jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str] jaccl_coordinators: dict[NodeId, str]
@staticmethod
def instance_meta() -> InstanceMeta:
    """Identify this instance kind as ``InstanceMeta.MlxJaccl``."""
    return InstanceMeta.MlxJaccl
# TODO: Single node instance # TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance Instance = MlxRingInstance | MlxJacclInstance