mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
nid-persis
...
consistent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8e5d3710f |
@@ -524,15 +524,15 @@ class API:
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
sharding=instance.sharding(),
|
||||
instance_meta=instance.instance_meta(),
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -541,8 +541,8 @@ class API:
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -234,6 +234,8 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -253,6 +255,6 @@ def get_node_id_keypair(
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate()
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -8,12 +8,12 @@ _EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
|
||||
|
||||
|
||||
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
|
||||
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo"""
|
||||
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
|
||||
|
||||
if _EXO_HOME_ENV is not None:
|
||||
return Path.home() / _EXO_HOME_ENV
|
||||
|
||||
if sys.platform != "linux" and env_var != "XDG_CACHE_HOME":
|
||||
if sys.platform != "linux":
|
||||
return Path.home() / ".exo"
|
||||
|
||||
xdg_value = os.environ.get(env_var, None)
|
||||
@@ -54,9 +54,10 @@ DASHBOARD_DIR = (
|
||||
# Log files (data/logs or cache)
|
||||
EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log"
|
||||
EXO_LOG = EXO_LOG_DIR / "exo.log"
|
||||
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
|
||||
|
||||
# Identity (config)
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair"
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
||||
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
|
||||
|
||||
# libp2p topics for event forwarding
|
||||
|
||||
@@ -94,27 +94,7 @@ def test_macos_uses_traditional_paths():
|
||||
home = Path.home()
|
||||
assert home / ".exo" == constants.EXO_CONFIG_HOME
|
||||
assert home / ".exo" == constants.EXO_DATA_HOME
|
||||
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
|
||||
|
||||
|
||||
def test_exo_home_env():
|
||||
"""Test that macOS uses traditional ~/.exo directory."""
|
||||
# Drop any inherited EXO_HOME, then set it explicitly below to exercise the override
|
||||
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
|
||||
env["EXO_HOME"] = "/exo"
|
||||
with (
|
||||
mock.patch.dict(os.environ, env, clear=True),
|
||||
mock.patch.object(sys, "platform", "darwin"),
|
||||
):
|
||||
import importlib
|
||||
|
||||
import exo.shared.constants as constants
|
||||
|
||||
importlib.reload(constants)
|
||||
|
||||
assert Path("/exo") == constants.EXO_CONFIG_HOME
|
||||
assert Path("/exo") == constants.EXO_DATA_HOME
|
||||
assert Path("/exo") == constants.EXO_CACHE_HOME
|
||||
assert home / ".exo" == constants.EXO_CACHE_HOME
|
||||
|
||||
|
||||
def test_node_id_in_config_dir():
|
||||
|
||||
@@ -4,7 +4,13 @@ from pydantic import model_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -24,16 +30,40 @@ class BaseInstance(TaggedModel):
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
    """Look up the shard assigned to ``runner_id``.

    Returns the entry stored for that runner in
    ``shard_assignments.runner_to_shard``, or ``None`` when the runner has
    no entry.
    """
    assignments = self.shard_assignments.runner_to_shard
    return assignments.get(runner_id)
|
||||
|
||||
@staticmethod
def instance_meta() -> InstanceMeta:
    """Return the ``InstanceMeta`` tag for this instance type.

    This base version is a stub whose body is ``...``; the subclasses in this
    file, ``MlxRingInstance`` and ``MlxJacclInstance``, provide the concrete
    values.
    """
    ...
|
||||
|
||||
def sharding(self) -> Sharding:
    """Report which sharding strategy this instance's shard assignments use.

    Returns ``Sharding.Pipeline`` when every assigned shard is a
    ``PipelineShardMetadata``, and ``Sharding.Tensor`` when every assigned
    shard is a ``TensorShardMetadata``. The pipeline check runs first, so an
    empty assignment map is reported as ``Sharding.Pipeline``.

    Raises:
        ValueError: if the assigned shards are neither all pipeline shards
            nor all tensor shards.
    """

    def uniformly(kind: type) -> bool:
        # True when every assigned shard is an instance of ``kind``.
        return all(
            isinstance(meta, kind)
            for meta in self.shard_assignments.runner_to_shard.values()
        )

    if uniformly(PipelineShardMetadata):
        return Sharding.Pipeline
    if uniformly(TensorShardMetadata):
        return Sharding.Tensor
    raise ValueError("shard metadata malformed")
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
    # Instance variant whose instance_meta() is InstanceMeta.MlxRing.

    # List of hosts for each node, keyed by node id.
    hosts_by_node: dict[NodeId, list[Host]]
    # NOTE(review): presumably the port the ring communicates on — confirm with the code that fills it in.
    ephemeral_port: int

    @staticmethod
    def instance_meta() -> InstanceMeta:
        """Return ``InstanceMeta.MlxRing``."""
        return InstanceMeta.MlxRing
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
    # Instance variant whose instance_meta() is InstanceMeta.MlxJaccl.

    # Nested list of optional strings.
    # NOTE(review): looks like a per-node matrix of device names — confirm the layout with the producer.
    jaccl_devices: list[list[str | None]]
    # One string per node id.
    # NOTE(review): presumably a coordinator address — confirm the expected format.
    jaccl_coordinators: dict[NodeId, str]

    @staticmethod
    def instance_meta() -> InstanceMeta:
        """Return ``InstanceMeta.MlxJaccl``."""
        return InstanceMeta.MlxJaccl
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
Instance = MlxRingInstance | MlxJacclInstance
|
||||
|
||||
Reference in New Issue
Block a user