mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
main
...
ciaran/use
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f003759c4 |
@@ -542,10 +542,13 @@ class InfoGatherer:
|
||||
if not p.stdout:
|
||||
logger.critical("MacMon closed stdout")
|
||||
return
|
||||
async for text in TextReceiveStream(
|
||||
BufferedByteReceiveStream(p.stdout)
|
||||
):
|
||||
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||
t = TextReceiveStream(BufferedByteReceiveStream(p.stdout))
|
||||
while True:
|
||||
with anyio.fail_after(self.macmon_interval * 3):
|
||||
macmon_output = await t.receive()
|
||||
await self.info_sender.send(
|
||||
MacmonMetrics.from_raw_json(macmon_output)
|
||||
)
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
@@ -556,8 +559,12 @@ class InfoGatherer:
|
||||
else str(stderr_output)
|
||||
)
|
||||
logger.warning(
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
f"memory monitor failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"memory monitor silent for {self.macmon_interval * 3}s - reloading"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in macmon monitor: {e}")
|
||||
logger.opt(exception=e).warning("Error in memory monitor")
|
||||
await anyio.sleep(self.macmon_interval)
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
@@ -98,14 +99,13 @@ def mlx_distributed_init(
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
coordination_file = None
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
coordination_file = str(
|
||||
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
@@ -128,9 +128,6 @@ def mlx_distributed_init(
|
||||
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
|
||||
)
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||
|
||||
with open(coordination_file, "w") as f:
|
||||
@@ -150,10 +147,6 @@ def mlx_distributed_init(
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
|
||||
Reference in New Issue
Block a user