This commit is contained in:
Evan
2025-12-03 13:47:05 +00:00
parent 2b243bd80e
commit 40a0d47de8

View File

@@ -3,6 +3,8 @@ import resource
import time
from pathlib import Path
from typing import Any, Callable, cast
import json
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
@@ -128,8 +130,6 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
import json
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
@@ -142,7 +142,7 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
group = mx.distributed.init(backend="ibv", strict=True)
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")