diff --git a/shared/node_id.py b/shared/node_id.py
index 564a87a2..3d7942f4 100644
--- a/shared/node_id.py
+++ b/shared/node_id.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import logging
-from multiprocessing import Lock
-from multiprocessing.synchronize import Lock as LockT
-from typing import Optional, TypedDict
+import os
+from pathlib import Path
 
 from exo_pyo3_bindings import Keypair
+from filelock import FileLock
 
 from shared.constants import EXO_NODE_ID_KEYPAIR
 
@@ -11,41 +13,32 @@ from shared.constants import EXO_NODE_ID_KEYPAIR
 This file is responsible for concurrent race-free persistent node-ID retrieval.
 """
 
-class _NodeIdGlobal(TypedDict):
-    file_lock: LockT
-    keypair: Optional[Keypair]
-
-_NODE_ID_GLOBAL: _NodeIdGlobal = {
-    "file_lock": Lock(),
-    "keypair": None,
-}
+def _lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
+    return Path(str(path) + ".lock")
 
-def get_node_id_keypair() -> Keypair:
+
+def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
     """
     Obtains the :class:`Keypair` associated with this node-ID.
 
     Obtain the :class:`PeerId` by from it.
     """
-    # get from memory if we have it => read from file otherwise
-    if _NODE_ID_GLOBAL["keypair"] is not None:
-        return _NODE_ID_GLOBAL["keypair"]
-
     # operate with cross-process lock to avoid race conditions
-    with _NODE_ID_GLOBAL["file_lock"]:
-        with open(EXO_NODE_ID_KEYPAIR, 'a+b') as f: # opens in append-mode => starts at EOF
+    with FileLock(_lock_path(path)):
+        with open(path, 'a+b') as f:  # opens in append-mode => starts at EOF
             # if non-zero EOF, then file exists => use to get node-ID
             if f.tell() != 0:
-                f.seek(0) # go to start & read protobuf-encoded bytes
+                f.seek(0)  # go to start & read protobuf-encoded bytes
                 protobuf_encoded = f.read()
-                try: # if decoded successfully, save & return
-                    _NODE_ID_GLOBAL["keypair"] = Keypair.from_protobuf_encoding(protobuf_encoded)
-                    return _NODE_ID_GLOBAL["keypair"]
-                except RuntimeError as e: # on runtime error, assume corrupt file
+                try:  # if decoded successfully, save & return
+                    return Keypair.from_protobuf_encoding(protobuf_encoded)
+                except RuntimeError as e:  # on runtime error, assume corrupt file
                     logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
 
     # if no valid credentials, create new ones and persist
-    with open(EXO_NODE_ID_KEYPAIR, 'w+b') as f:
-        _NODE_ID_GLOBAL["keypair"] = Keypair.generate_ed25519()
-        f.write(_NODE_ID_GLOBAL["keypair"].to_protobuf_encoding())
-        return _NODE_ID_GLOBAL["keypair"]
\ No newline at end of file
+    with open(path, 'w+b') as f:
+        keypair = Keypair.generate_ed25519()
+        f.write(keypair.to_protobuf_encoding())
+        return keypair
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index 78920a59..05d3ff74 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -5,6 +5,7 @@ description = "Shared utilities for the Exo project"
 readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
+    "filelock>=3.18.0",
     "aiosqlite>=0.20.0",
     "networkx>=3.5",
     "openai>=1.93.0",
diff --git a/shared/tests/test_node_id_persistence.py b/shared/tests/test_node_id_persistence.py
index 6f030b74..44943f49 100644
--- a/shared/tests/test_node_id_persistence.py
+++ b/shared/tests/test_node_id_persistence.py
@@ -1,7 +1,11 @@
+from __future__ import annotations
+
 import contextlib
 import logging
+import multiprocessing
 import os
-from multiprocessing import Event, Process, Queue, Semaphore
+from multiprocessing import Event, Queue, Semaphore
+from multiprocessing.process import BaseProcess
 from multiprocessing.queues import Queue as QueueT
 from multiprocessing.synchronize import Event as EventT
 from multiprocessing.synchronize import Semaphore as SemaphoreT
@@ -14,10 +18,9 @@ from shared.node_id import get_node_id_keypair
 
 NUM_CONCURRENT_PROCS = 10
 
-def _get_keypair_concurrent(num_procs: int) -> bytes:
-    assert num_procs > 0
 
-    def subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+def _get_keypair_concurrent_subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+    try:
         # synchronise with parent process
         logging.info(msg=f"SUBPROCESS {pid}: Started")
         sem.release()
@@ -27,9 +30,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
         logging.info(msg=f"SUBPROCESS {pid}: Reading start")
         queue.put(get_node_id_keypair().to_protobuf_encoding())
         logging.info(msg=f"SUBPROCESS {pid}: Reading end")
+    except Exception as e:
+        logging.error(msg=f"SUBPROCESS {pid}: Error encountered: {e}")
 
-        # notify master of finishing
-        sem.release()
+
+def _get_keypair_concurrent(num_procs: int) -> bytes:
+    assert num_procs > 0
 
     sem = Semaphore(0)
     ev = Event()
@@ -37,8 +43,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
 
     # make parent process wait for all subprocesses to start
     logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
+    ps: list[BaseProcess] = []
     for i in range(num_procs):
-        Process(target=subprocess_task, args=(i + 1, sem, ev, queue)).start()
+        p = multiprocessing.get_context("fork").Process(target=_get_keypair_concurrent_subprocess_task,
+                                                        args=(i + 1, sem, ev, queue))
+        ps.append(p)
+        p.start()
     for _ in range(num_procs):
         sem.acquire()
 
@@ -47,26 +57,30 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     ev.set()
 
     # wait until all subprocesses are done & read results
-    for _ in range(num_procs):
-        sem.acquire()
+    for p in ps:
+        p.join()
 
     # check that the input/output order match, and that
     # all subprocesses end up reading the same file
    logging.info(msg="PARENT: Checking consistency")
     keypair: Optional[bytes] = None
-    assert queue.qsize() > 0
-    while queue.qsize() > 0:
+    qsize = 0  # cannot use Queue.qsize due to MacOS incompatibility :(
+    while not queue.empty():
+        qsize += 1
         temp_keypair = queue.get()
         if keypair is None:
             keypair = temp_keypair
         else:
             assert keypair == temp_keypair
-    return keypair # pyright: ignore[reportReturnType]
+
+    assert num_procs == qsize
+    return keypair  # pyright: ignore[reportReturnType]
+
 
 def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
     with contextlib.suppress(OSError):
         os.remove(p)
 
+
 def test_node_id_fetching(caplog: LogCaptureFixture):
     reps = 10
@@ -74,7 +88,7 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
         _delete_if_exists(EXO_NODE_ID_KEYPAIR)
         kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
 
-    with caplog.at_level(logging.CRITICAL): # supress logs
+    with caplog.at_level(logging.CRITICAL):  # suppress logs
         # make sure that continuous fetches return the same value
         for _ in range(reps):
             assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
@@ -82,4 +96,4 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     # make sure that after deleting, we are not fetching the same value
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     for _ in range(reps):
-        assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
\ No newline at end of file
+        assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
diff --git a/shared/utils.py b/shared/utils.py
index bf2be769..974091eb 100644
--- a/shared/utils.py
+++ b/shared/utils.py
@@ -1,9 +1,7 @@
-from typing import Any, Type, TypeVar
-
-T = TypeVar("T")
+from typing import Any, Type
 
 
-def ensure_type(obj: Any, expected_type: Type[T]) -> T: # type: ignore
+def ensure_type[T](obj: Any, expected_type: Type[T]) -> T:  # type: ignore
     if not isinstance(obj, expected_type):
         raise TypeError(f"Expected {expected_type}, got {type(obj)}")  # type: ignore
     return obj
diff --git a/uv.lock b/uv.lock
index 3c541f99..d771b989 100644
--- a/uv.lock
+++ b/uv.lock
@@ -269,6 +269,7 @@ version = "0.1.0"
 source = { editable = "shared" }
 dependencies = [
     { name = "aiosqlite", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -291,6 +292,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "aiosqlite", specifier = ">=0.20.0" },
+    { name = "filelock", specifier = ">=3.18.0" },
    { name = "greenlet", specifier = ">=3.2.3" },
     { name = "networkx", specifier = ">=3.5" },
     { name = "openai", specifier = ">=1.93.0" },
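
Note on the locking change in `shared/node_id.py`: the old in-memory `multiprocessing.Lock` only serialized processes that inherited that specific lock object, so two independently launched processes could still race on the keypair file. `filelock.FileLock` contends on an on-disk sibling file (see `_lock_path`), which any process on the machine can acquire. Below is a minimal self-contained sketch of the same read-or-create pattern; `read_or_create`, the demo path, and the `secrets.token_bytes` payload are illustrative stand-ins for the real `Keypair` API, not project code:

```python
import secrets
from pathlib import Path

from filelock import FileLock  # pip install filelock


def read_or_create(path: Path) -> bytes:
    """Return the persisted blob, creating it under a cross-process lock if absent."""
    # Contend on a sibling ".lock" file (mirrors _lock_path in the diff), so the
    # data file itself is only read or written while the lock is held.
    with FileLock(str(path) + ".lock"):
        if path.exists() and path.stat().st_size > 0:
            return path.read_bytes()
        blob = secrets.token_bytes(32)  # stand-in for Keypair.generate_ed25519()
        path.write_bytes(blob)
        return blob


if __name__ == "__main__":
    p = Path("demo_node_id.bin")
    assert read_or_create(p) == read_or_create(p)  # second call reads, not regenerates
```

As in the patched `get_node_id_keypair`, the first writer wins and every later caller reads the same bytes, regardless of which process acquires the lock first.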
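
The `shared/utils.py` hunk drops the module-level `TypeVar` in favour of PEP 695 inline generic syntax (Python 3.12+, comfortably within the project's `requires-python = ">=3.13"`). A hypothetical call site, just to show the type variable is now scoped to the function while the call shape stays the same:

```python
from typing import Any


def ensure_type[T](obj: Any, expected_type: type[T]) -> T:
    # Narrow an untyped value to T, or fail loudly.
    if not isinstance(obj, expected_type):
        raise TypeError(f"Expected {expected_type}, got {type(obj)}")
    return obj


# T binds per call; no shared module-level TypeVar needed.
port: int = ensure_type(8080, int)
name: str = ensure_type("exo", str)
```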