Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
3f2a49096b persist node ids in .cache
brings back EXO_CACHE_HOME as always ~/.cache/exo/, and store the node
id in there. no random copies now!
2026-02-25 18:01:39 +00:00
4 changed files with 37 additions and 13 deletions

View File

@@ -234,8 +234,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -255,6 +253,6 @@ def get_node_id_keypair(
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
keypair = Keypair.generate()
f.write(keypair.to_bytes())
return keypair

View File

@@ -8,12 +8,12 @@ _EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo"""
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
if sys.platform != "linux":
if sys.platform != "linux" and env_var != "XDG_CACHE_HOME":
return Path.home() / ".exo"
xdg_value = os.environ.get(env_var, None)
@@ -54,10 +54,9 @@ DASHBOARD_DIR = (
# Log files (data/logs or cache)
EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log"
EXO_LOG = EXO_LOG_DIR / "exo.log"
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding

View File

@@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths():
home = Path.home()
assert home / ".exo" == constants.EXO_CONFIG_HOME
assert home / ".exo" == constants.EXO_DATA_HOME
assert home / ".exo" == constants.EXO_CACHE_HOME
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
def test_exo_home_env():
"""Test that macOS uses traditional ~/.exo directory."""
# Remove EXO_HOME to ensure we test the default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
env["EXO_HOME"] = "/exo"
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "darwin"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert Path("/exo") == constants.EXO_CONFIG_HOME
assert Path("/exo") == constants.EXO_DATA_HOME
assert Path("/exo") == constants.EXO_CACHE_HOME
def test_node_id_in_config_dir():

View File

@@ -2,7 +2,6 @@ import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, cast
@@ -99,13 +98,14 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
with tempfile.TemporaryDirectory() as tmpdir:
coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,6 +128,9 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
@@ -147,6 +150,10 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(