mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
ciaran/use
...
nid-persis
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f2a49096b |
@@ -234,8 +234,6 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -255,6 +253,6 @@ def get_node_id_keypair(
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
keypair = Keypair.generate()
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -8,12 +8,12 @@ _EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
|
||||
|
||||
|
||||
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
|
||||
"""Get XDG directory, prioritising EXO_HOME environment variable if it's set. On non-Linux platforms, default to ~/.exo."""
|
||||
"""Get XDG directory, prioritising EXO_HOME environment variable if it's set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo."""
|
||||
|
||||
if _EXO_HOME_ENV is not None:
|
||||
return Path.home() / _EXO_HOME_ENV
|
||||
|
||||
if sys.platform != "linux":
|
||||
if sys.platform != "linux" and env_var != "XDG_CACHE_HOME":
|
||||
return Path.home() / ".exo"
|
||||
|
||||
xdg_value = os.environ.get(env_var, None)
|
||||
@@ -54,10 +54,9 @@ DASHBOARD_DIR = (
|
||||
# Log files (data/logs or cache)
|
||||
EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log"
|
||||
EXO_LOG = EXO_LOG_DIR / "exo.log"
|
||||
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
|
||||
|
||||
# Identity (config)
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair"
|
||||
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
|
||||
|
||||
# libp2p topics for event forwarding
|
||||
|
||||
@@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths():
|
||||
home = Path.home()
|
||||
assert home / ".exo" == constants.EXO_CONFIG_HOME
|
||||
assert home / ".exo" == constants.EXO_DATA_HOME
|
||||
assert home / ".exo" == constants.EXO_CACHE_HOME
|
||||
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
|
||||
|
||||
|
||||
def test_exo_home_env():
|
||||
"""Test that EXO_HOME overrides the config, data and cache directories on macOS."""
|
||||
# Copy the environment without any inherited EXO_HOME, then set it explicitly below
|
||||
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
|
||||
env["EXO_HOME"] = "/exo"
|
||||
with (
|
||||
mock.patch.dict(os.environ, env, clear=True),
|
||||
mock.patch.object(sys, "platform", "darwin"),
|
||||
):
|
||||
import importlib
|
||||
|
||||
import exo.shared.constants as constants
|
||||
|
||||
importlib.reload(constants)
|
||||
|
||||
assert Path("/exo") == constants.EXO_CONFIG_HOME
|
||||
assert Path("/exo") == constants.EXO_DATA_HOME
|
||||
assert Path("/exo") == constants.EXO_CACHE_HOME
|
||||
|
||||
|
||||
def test_node_id_in_config_dir():
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
@@ -99,13 +98,14 @@ def mlx_distributed_init(
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
coordination_file = str(
|
||||
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
coordination_file = None
|
||||
try:
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
@@ -128,6 +128,9 @@ def mlx_distributed_init(
|
||||
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
|
||||
)
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||
|
||||
with open(coordination_file, "w") as f:
|
||||
@@ -147,6 +150,10 @@ def mlx_distributed_init(
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
|
||||
Reference in New Issue
Block a user