Compare commits

..

1 Commits

Author SHA1 Message Date
ciaranbor
6f003759c4 Use tmpdir for coordination file 2026-02-25 18:57:57 +00:00
3 changed files with 17 additions and 12 deletions

View File

@@ -27,6 +27,7 @@ dependencies = [
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.5",
"python-multipart>=0.0.21",
"msgspec>=0.19.0",
"zstandard>=0.23.0",
]

View File

@@ -2,6 +2,7 @@ import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, cast
@@ -98,14 +99,13 @@ def mlx_distributed_init(
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
coordination_file = None
try:
with tempfile.TemporaryDirectory() as tmpdir:
coordination_file = str(
Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
@@ -128,9 +128,6 @@ def mlx_distributed_init(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
@@ -150,10 +147,6 @@ def mlx_distributed_init(
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(

11
uv.lock generated
View File

@@ -385,6 +385,7 @@ dependencies = [
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "python-multipart", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -423,6 +424,7 @@ requires-dist = [
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "python-multipart", specifier = ">=0.0.21" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
@@ -1882,6 +1884,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
]
[[package]]
name = "python-multipart"
version = "0.0.21"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/78/96/804520d0850c7db98e5ccb70282e29208723f0964e88ffd9d0da2f52ea09/python_multipart-0.0.21.tar.gz", hash = "sha256:7137ebd4d3bbf70ea1622998f902b97a29434a9e8dc40eb203bbcf7c2a2cba92", size = 37196, upload-time = "2025-12-17T09:24:22.446Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/aa/76/03af049af4dcee5d27442f71b6924f01f3efb5d2bd34f23fcd563f2cc5f5/python_multipart-0.0.21-py3-none-any.whl", hash = "sha256:cf7a6713e01c87aa35387f4774e812c4361150938d20d232800f75ffcf266090", size = 24541, upload-time = "2025-12-17T09:24:21.153Z" },
]
[[package]]
name = "pyyaml"
version = "6.0.3"