rename ibv to jaccl inline with mlx

This commit is contained in:
Evan Quiney
2025-12-05 16:42:43 +00:00
committed by GitHub
parent f5783d6455
commit 9e0a1c23ef
6 changed files with 14 additions and 14 deletions

View File

@@ -26,7 +26,7 @@ from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxIbvInstance,
MlxJacclInstance,
MlxRingInstance,
)
@@ -105,7 +105,7 @@ def get_instance_placements_after_create(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxIbv:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
@@ -114,7 +114,7 @@ def get_instance_placements_after_create(
selected_cycle,
coordinator_port=random_ephemeral_port(),
)
target_instances[instance_id] = MlxIbvInstance(
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,

View File

@@ -19,7 +19,7 @@ from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxIbvInstance,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import ShardAssignments
@@ -422,7 +422,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
cic = CreateInstance(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxIbv,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
min_nodes=1,
@@ -434,7 +434,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxIbvInstance)
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinator is not None

View File

@@ -13,7 +13,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxIbv = "MlxIbv"
MlxJaccl = "MlxJaccl"
class BaseInstance(TaggedModel):
@@ -28,13 +28,13 @@ class MlxRingInstance(BaseInstance):
hosts: list[Host]
class MlxIbvInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinator: str
# TODO: Single node instance
Instance = MlxRingInstance | MlxIbvInstance
Instance = MlxRingInstance | MlxJacclInstance
class BoundInstance(CamelCaseModel):

View File

@@ -32,7 +32,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
MlxIbvInstance,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import (
@@ -128,7 +128,7 @@ def mlx_distributed_init(
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)

View File

@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxIbvInstance
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.utils.channels import MpReceiver, MpSender
logger: "loguru.Logger"
@@ -21,7 +21,7 @@ def entrypoint(
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxIbvInstance)
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

View File

@@ -37,7 +37,7 @@ Integration tests:
2. Test that node count does not affect inference result (per-configuration)
- Llama on 1 node, and on 2 nodes returns the same result, given temperature 0 and set seed.
- Do for all configurations (Ring/Ibv, Pipeline/Tensor)
- Do for all configurations (Ring/Jaccl, Pipeline/Tensor)
3. Test supervisor catches exceptions gracefully
- Timeouts