From 40a0d47de8e2f5ffd60690b32b14a95a7c09420d Mon Sep 17 00:00:00 2001 From: Evan Date: Wed, 3 Dec 2025 13:47:05 +0000 Subject: [PATCH] jaccl --- src/exo/worker/engines/mlx/utils_mlx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index c9f47449..c0540a9d 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -3,6 +3,8 @@ import resource import time from pathlib import Path from typing import Any, Callable, cast +import json + from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache from mlx_lm.models.deepseek_v3 import DeepseekV3Model @@ -128,8 +130,6 @@ def mlx_distributed_init( group = mx.distributed.init(backend="ring", strict=True) case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator): - import json - # Use RDMA connectivity matrix devices_file = f"./hosts_{rank}.json" ibv_devices_json = json.dumps(ibv_devices) @@ -142,7 +142,7 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = devices_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator - group = mx.distributed.init(backend="ibv", strict=True) + group = mx.distributed.init(backend="jaccl", strict=True) logger.info(f"Rank {rank} mlx distributed initialization complete")