Compare commits

...

2 Commits

Author SHA1 Message Date
Evan
c3bf082ab9 persist node ids in .cache
brings back EXO_CACHE_HOME as always ~/.cache/exo/, and store the node
id in there. no random copies now!
2026-02-26 13:24:58 +00:00
ciaranbor
eaed92952c Use tmpdir for coordination file (#1624)
## Motivation

Coordination files for MLX distributed init were written to the current
working directory (./hosts_*.json)

## Changes

- Move coordination file creation to a tempfile.TemporaryDirectory(),
which auto-cleans on context manager exit
2026-02-26 10:59:36 +00:00
4 changed files with 30 additions and 20 deletions

View File

@@ -234,8 +234,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -255,6 +253,6 @@ def get_node_id_keypair(
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
keypair = Keypair.generate()
f.write(keypair.to_bytes())
return keypair

View File

@@ -8,12 +8,12 @@ _EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo"""
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
if sys.platform != "linux":
if sys.platform != "linux" and env_var != "XDG_CACHE_HOME":
return Path.home() / ".exo"
xdg_value = os.environ.get(env_var, None)
@@ -54,10 +54,9 @@ DASHBOARD_DIR = (
# Log files (data/logs or cache)
EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log"
EXO_LOG = EXO_LOG_DIR / "exo.log"
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding

View File

@@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths():
home = Path.home()
assert home / ".exo" == constants.EXO_CONFIG_HOME
assert home / ".exo" == constants.EXO_DATA_HOME
assert home / ".exo" == constants.EXO_CACHE_HOME
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
def test_exo_home_env():
"""Test that macOS uses traditional ~/.exo directory."""
# Remove EXO_HOME to ensure we test the default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
env["EXO_HOME"] = "/exo"
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "darwin"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert Path("/exo") == constants.EXO_CONFIG_HOME
assert Path("/exo") == constants.EXO_DATA_HOME
assert Path("/exo") == constants.EXO_CACHE_HOME
def test_node_id_in_config_dir():

View File

@@ -2,6 +2,7 @@ import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, cast
@@ -98,14 +99,13 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
coordination_file = None
try:
with tempfile.TemporaryDirectory() as tmpdir:
coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,9 +128,6 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
@@ -150,10 +147,6 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(