diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index 63ff8526..9683689b 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -1,35 +1,46 @@ import os +import sys from pathlib import Path -EXO_HOME_RELATIVE_PATH = os.environ.get("EXO_HOME", ".exo") -EXO_HOME = Path.home() / EXO_HOME_RELATIVE_PATH +_EXO_HOME_ENV = os.environ.get("EXO_HOME", None) -EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR") -EXO_MODELS_DIR = Path(EXO_MODELS_DIR_ENV) if EXO_MODELS_DIR_ENV else EXO_HOME / "models" -EXO_GLOBAL_EVENT_DB = EXO_HOME / "global_events.db" -EXO_WORKER_EVENT_DB = EXO_HOME / "worker_events.db" -EXO_MASTER_STATE = EXO_HOME / "master_state.json" -EXO_WORKER_STATE = EXO_HOME / "worker_state.json" -EXO_MASTER_LOG = EXO_HOME / "master.log" -EXO_WORKER_LOG = EXO_HOME / "worker.log" -EXO_LOG = EXO_HOME / "exo.log" -EXO_TEST_LOG = EXO_HOME / "exo_test.log" +def _get_xdg_dir(env_var: str, fallback: str) -> Path: + """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.""" -EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair" + if _EXO_HOME_ENV is not None: + return Path.home() / _EXO_HOME_ENV -EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring" -EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring" + if sys.platform != "linux": + return Path.home() / ".exo" -EXO_IPC_DIR = EXO_HOME / "ipc" + xdg_value = os.environ.get(env_var, None) + if xdg_value is not None: + return Path(xdg_value) / "exo" + return Path.home() / fallback / "exo" + + +EXO_CONFIG_HOME = _get_xdg_dir("XDG_CONFIG_HOME", ".config") +EXO_DATA_HOME = _get_xdg_dir("XDG_DATA_HOME", ".local/share") +EXO_CACHE_HOME = _get_xdg_dir("XDG_CACHE_HOME", ".cache") + +# Models directory (data) +_EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None) +EXO_MODELS_DIR = ( + EXO_DATA_HOME / "models" + if _EXO_MODELS_DIR_ENV is None + else Path.home() / _EXO_MODELS_DIR_ENV +) + +# Log files (data/logs or cache) +EXO_LOG = EXO_CACHE_HOME / "exo.log" +EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log" + +# Identity (config) +EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" # libp2p topics for event forwarding LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events" LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events" LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message" LIBP2P_COMMANDS_TOPIC = "commands" - -# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip. -LB_TFLOPS = 2.3 -LB_MEMBW_GBPS = 68 -LB_DISK_GBPS = 1.5 diff --git a/src/exo/shared/tests/test_xdg_paths.py b/src/exo/shared/tests/test_xdg_paths.py new file mode 100644 index 00000000..73f81d10 --- /dev/null +++ b/src/exo/shared/tests/test_xdg_paths.py @@ -0,0 +1,118 @@ +"""Tests for XDG Base Directory Specification compliance.""" + +import os +import sys +from pathlib import Path +from unittest import mock + + +def test_xdg_paths_on_linux(): + """Test that XDG paths are used on Linux when XDG env vars are set.""" + with ( + mock.patch.dict( + os.environ, + { + "XDG_CONFIG_HOME": "/tmp/test-config", + "XDG_DATA_HOME": "/tmp/test-data", + "XDG_CACHE_HOME": "/tmp/test-cache", + }, + clear=False, + ), + mock.patch.object(sys, "platform", "linux"), + ): + # Re-import to pick up mocked values + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert Path("/tmp/test-config/exo") == constants.EXO_CONFIG_HOME + assert Path("/tmp/test-data/exo") == constants.EXO_DATA_HOME + assert Path("/tmp/test-cache/exo") == constants.EXO_CACHE_HOME + + +def test_xdg_default_paths_on_linux(): + """Test that XDG default paths are used on Linux when env vars are not set.""" + # Remove XDG env vars and EXO_HOME + env = { + k: v + for k, v in os.environ.items() + if not k.startswith("XDG_") and k != "EXO_HOME" + } + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "linux"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + home = Path.home() + assert home / ".config" / "exo" == constants.EXO_CONFIG_HOME + assert home / ".local/share" / "exo" == constants.EXO_DATA_HOME + assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME + + +def test_legacy_exo_home_takes_precedence(): + """Test that EXO_HOME environment variable takes precedence for backward compatibility.""" + with mock.patch.dict( + os.environ, + { + "EXO_HOME": ".custom-exo", + "XDG_CONFIG_HOME": "/tmp/test-config", + }, + clear=False, + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + home = Path.home() + assert home / ".custom-exo" == constants.EXO_CONFIG_HOME + assert home / ".custom-exo" == constants.EXO_DATA_HOME + + +def test_macos_uses_traditional_paths(): + """Test that macOS uses traditional ~/.exo directory.""" + # Remove EXO_HOME to ensure we test the default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"} + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "darwin"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + home = Path.home() + assert home / ".exo" == constants.EXO_CONFIG_HOME + assert home / ".exo" == constants.EXO_DATA_HOME + assert home / ".exo" == constants.EXO_CACHE_HOME + + +def test_node_id_in_config_dir(): + """Test that node ID keypair is in the config directory.""" + import exo.shared.constants as constants + + assert constants.EXO_NODE_ID_KEYPAIR.parent == constants.EXO_CONFIG_HOME + + +def test_models_in_data_dir(): + """Test that models directory is in the data directory.""" + # Clear EXO_MODELS_DIR to test default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_MODELS_DIR"} + with mock.patch.dict(os.environ, env, clear=True): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert constants.EXO_MODELS_DIR.parent == constants.EXO_DATA_HOME diff --git a/src/exo/worker/download/download_utils.py b/src/exo/worker/download/download_utils.py index 97b4e00e..8d033471 100644 --- a/src/exo/worker/download/download_utils.py +++ b/src/exo/worker/download/download_utils.py @@ -24,7 +24,7 @@ from pydantic import ( TypeAdapter, ) -from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR +from exo.shared.constants import EXO_MODELS_DIR from exo.shared.types.memory import Memory from exo.shared.types.worker.downloads import DownloadProgressData from exo.shared.types.worker.shards import ShardMetadata @@ -132,25 +132,6 @@ async def resolve_model_path_for_repo(repo_id: str) -> Path: return (await ensure_models_dir()) / repo_id.replace("/", "--") -async def ensure_exo_home() -> Path: - await aios.makedirs(EXO_HOME, exist_ok=True) - return EXO_HOME - - -async def has_exo_home_read_access() -> bool: - try: - return await aios.access(EXO_HOME, os.R_OK) - except OSError: - return False - - -async def has_exo_home_write_access() -> bool: - try: - return await aios.access(EXO_HOME, os.W_OK) - except OSError: - return False - - async def ensure_models_dir() -> Path: await aios.makedirs(EXO_MODELS_DIR, exist_ok=True) return EXO_MODELS_DIR