mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 14:17:58 -05:00
jaccl
This commit is contained in:
@@ -3,6 +3,8 @@ import resource
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast
|
||||
import json
|
||||
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
@@ -128,8 +130,6 @@ def mlx_distributed_init(
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxIbvInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
|
||||
import json
|
||||
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
@@ -142,7 +142,7 @@ def mlx_distributed_init(
|
||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
|
||||
group = mx.distributed.init(backend="ibv", strict=True)
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user