mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
rename ibv to jaccl, in line with mlx
This commit is contained in:
@@ -26,7 +26,7 @@ from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxIbvInstance,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
|
||||
@@ -105,7 +105,7 @@ def get_instance_placements_after_create(
|
||||
|
||||
# TODO: Single node instances
|
||||
match command.instance_meta:
|
||||
case InstanceMeta.MlxIbv:
|
||||
case InstanceMeta.MlxJaccl:
|
||||
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
|
||||
selected_cycle,
|
||||
cycle_digraph,
|
||||
@@ -114,7 +114,7 @@ def get_instance_placements_after_create(
|
||||
selected_cycle,
|
||||
coordinator_port=random_ephemeral_port(),
|
||||
)
|
||||
target_instances[instance_id] = MlxIbvInstance(
|
||||
target_instances[instance_id] = MlxJacclInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
|
||||
@@ -19,7 +19,7 @@ from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxIbvInstance,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import ShardAssignments
|
||||
@@ -422,7 +422,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
cic = CreateInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
instance_meta=InstanceMeta.MlxIbv,
|
||||
instance_meta=InstanceMeta.MlxJaccl,
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta,
|
||||
min_nodes=1,
|
||||
@@ -434,7 +434,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance = placements[instance_id]
|
||||
|
||||
assert isinstance(instance, MlxIbvInstance)
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.ibv_coordinator is not None
|
||||
|
||||
@@ -13,7 +13,7 @@ class InstanceId(Id):
|
||||
|
||||
class InstanceMeta(str, Enum):
|
||||
MlxRing = "MlxRing"
|
||||
MlxIbv = "MlxIbv"
|
||||
MlxJaccl = "MlxJaccl"
|
||||
|
||||
|
||||
class BaseInstance(TaggedModel):
|
||||
@@ -28,13 +28,13 @@ class MlxRingInstance(BaseInstance):
|
||||
hosts: list[Host]
|
||||
|
||||
|
||||
class MlxIbvInstance(BaseInstance):
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
ibv_devices: list[list[str | None]]
|
||||
ibv_coordinator: str
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
Instance = MlxRingInstance | MlxIbvInstance
|
||||
Instance = MlxRingInstance | MlxJacclInstance
|
||||
|
||||
|
||||
class BoundInstance(CamelCaseModel):
|
||||
|
||||
@@ -32,7 +32,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
MlxIbvInstance,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import (
|
||||
@@ -128,7 +128,7 @@ def mlx_distributed_init(
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
|
||||
case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
|
||||
@@ -4,7 +4,7 @@ import loguru
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxIbvInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
|
||||
logger: "loguru.Logger"
|
||||
@@ -21,7 +21,7 @@ def entrypoint(
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxIbvInstance)
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
|
||||
@@ -37,7 +37,7 @@ Integration tests:
|
||||
|
||||
2. Test that node count does not affect inference result (per-configuration)
|
||||
- Llama on 1 node and on 2 nodes returns the same result, given temperature 0 and a fixed seed.
|
||||
- Do for all configurations (Ring/Ibv, Pipeline/Tensor)
|
||||
- Do for all configurations (Ring/Jaccl, Pipeline/Tensor)
|
||||
|
||||
3. Test supervisor catches exceptions gracefully
|
||||
- Timeouts
|
||||
|
||||
Reference in New Issue
Block a user