From 9e0a1c23ef2c6074ea4a20ca3d92a480f7fadfbe Mon Sep 17 00:00:00 2001 From: Evan Quiney Date: Fri, 5 Dec 2025 16:42:43 +0000 Subject: [PATCH] rename ibv to jaccl inline with mlx --- src/exo/master/placement.py | 6 +++--- src/exo/master/tests/test_placement.py | 6 +++--- src/exo/shared/types/worker/instances.py | 6 +++--- src/exo/worker/engines/mlx/utils_mlx.py | 4 ++-- src/exo/worker/runner/bootstrap.py | 4 ++-- src/exo/worker/tests/TODO.tests | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index e3c0adbc..98742924 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -26,7 +26,7 @@ from exo.shared.types.worker.instances import ( Instance, InstanceId, InstanceMeta, - MlxIbvInstance, + MlxJacclInstance, MlxRingInstance, ) @@ -105,7 +105,7 @@ def get_instance_placements_after_create( # TODO: Single node instances match command.instance_meta: - case InstanceMeta.MlxIbv: + case InstanceMeta.MlxJaccl: mlx_ibv_devices = get_mlx_ibv_devices_matrix( selected_cycle, cycle_digraph, @@ -114,7 +114,7 @@ def get_instance_placements_after_create( selected_cycle, coordinator_port=random_ephemeral_port(), ) - target_instances[instance_id] = MlxIbvInstance( + target_instances[instance_id] = MlxJacclInstance( instance_id=instance_id, shard_assignments=shard_assignments, ibv_devices=mlx_ibv_devices, diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 0eb7bd67..95cb33bc 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -19,7 +19,7 @@ from exo.shared.types.worker.instances import ( Instance, InstanceId, InstanceMeta, - MlxIbvInstance, + MlxJacclInstance, MlxRingInstance, ) from exo.shared.types.worker.runners import ShardAssignments @@ -422,7 +422,7 @@ def test_tensor_rdma_backend_connectivity_matrix( cic = CreateInstance( sharding=Sharding.Tensor, - instance_meta=InstanceMeta.MlxIbv, + instance_meta=InstanceMeta.MlxJaccl, command_id=CommandId(), model_meta=model_meta, min_nodes=1, @@ -434,7 +434,7 @@ def test_tensor_rdma_backend_connectivity_matrix( instance_id = list(placements.keys())[0] instance = placements[instance_id] - assert isinstance(instance, MlxIbvInstance) + assert isinstance(instance, MlxJacclInstance) assert instance.ibv_devices is not None assert instance.ibv_coordinator is not None diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index b68e60a4..e36c4fb0 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -13,7 +13,7 @@ class InstanceId(Id): class InstanceMeta(str, Enum): MlxRing = "MlxRing" - MlxIbv = "MlxIbv" + MlxJaccl = "MlxJaccl" class BaseInstance(TaggedModel): @@ -28,13 +28,13 @@ class MlxRingInstance(BaseInstance): hosts: list[Host] -class MlxIbvInstance(BaseInstance): +class MlxJacclInstance(BaseInstance): ibv_devices: list[list[str | None]] ibv_coordinator: str # TODO: Single node instance -Instance = MlxRingInstance | MlxIbvInstance +Instance = MlxRingInstance | MlxJacclInstance class BoundInstance(CamelCaseModel): diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 59ad30a9..dc6d1e45 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -32,7 +32,7 @@ from exo.shared.types.memory import Memory from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.worker.instances import ( BoundInstance, - MlxIbvInstance, + MlxJacclInstance, MlxRingInstance, ) from exo.shared.types.worker.shards import ( @@ -128,7 +128,7 @@ def mlx_distributed_init( os.environ["MLX_RING_VERBOSE"] = "1" group = mx.distributed.init(backend="ring", strict=True) - case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator): + case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator): # Use RDMA connectivity matrix devices_file = f"./hosts_{rank}.json" ibv_devices_json = json.dumps(ibv_devices) diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index 4f9f6f28..24d30cb8 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -4,7 +4,7 @@ import loguru from exo.shared.types.events import Event from exo.shared.types.tasks import Task -from exo.shared.types.worker.instances import BoundInstance, MlxIbvInstance +from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance from exo.utils.channels import MpReceiver, MpSender logger: "loguru.Logger" @@ -21,7 +21,7 @@ def entrypoint( _logger: "loguru.Logger", ) -> None: if ( - isinstance(bound_instance.instance, MlxIbvInstance) + isinstance(bound_instance.instance, MlxJacclInstance) and len(bound_instance.instance.ibv_devices) >= 2 ): os.environ["MLX_METAL_FAST_SYNCH"] = "1" diff --git a/src/exo/worker/tests/TODO.tests b/src/exo/worker/tests/TODO.tests index ab667fc3..de72268b 100644 --- a/src/exo/worker/tests/TODO.tests +++ b/src/exo/worker/tests/TODO.tests @@ -37,7 +37,7 @@ Integration tests: 2. Test that node count does not affect inference result (per-configuration) - Llama on 1 node, and on 2 nodes returns the same result, given temperature 0 and set seed. - - Do for all configurations (Ring/Ibv, Pipeline/Tensor) + - Do for all configurations (Ring/Jaccl, Pipeline/Tensor) 3. Test supervisor catches exceptions gracefully - Timeouts