Fix the node-ID test

Co-authored-by: Matt Beton <matthew.beton@gmail.com>
Andrei Cravtov authored 2025-07-24 17:09:12 +01:00, committed by GitHub
parent df1fe3af26
commit 3730160477
5 changed files with 52 additions and 44 deletions


@@ -1,9 +1,11 @@
 from __future__ import annotations
 import logging
-from multiprocessing import Lock
-from multiprocessing.synchronize import Lock as LockT
-from typing import Optional, TypedDict
+import os
+from pathlib import Path
 from exo_pyo3_bindings import Keypair
+from filelock import FileLock
 from shared.constants import EXO_NODE_ID_KEYPAIR
@@ -11,41 +13,32 @@ from shared.constants import EXO_NODE_ID_KEYPAIR
 """
 This file is responsible for concurrent race-free persistent node-ID retrieval.
 """
-class _NodeIdGlobal(TypedDict):
-    file_lock: LockT
-    keypair: Optional[Keypair]
-_NODE_ID_GLOBAL: _NodeIdGlobal = {
-    "file_lock": Lock(),
-    "keypair": None,
-}
-def get_node_id_keypair() -> Keypair:
+def _lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
+    return Path(str(path) + ".lock")
+def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
     """
     Obtains the :class:`Keypair` associated with this node-ID.
     Obtain the :class:`PeerId` from it.
     """
-    # get from memory if we have it => read from file otherwise
-    if _NODE_ID_GLOBAL["keypair"] is not None:
-        return _NODE_ID_GLOBAL["keypair"]
     # operate with cross-process lock to avoid race conditions
-    with _NODE_ID_GLOBAL["file_lock"]:
-        with open(EXO_NODE_ID_KEYPAIR, 'a+b') as f:  # opens in append-mode => starts at EOF
+    with FileLock(_lock_path(path)):
+        with open(path, 'a+b') as f:  # opens in append-mode => starts at EOF
             # if non-zero EOF, then file exists => use to get node-ID
             if f.tell() != 0:
-                f.seek(0) # go to start & read protobuf-encoded bytes
+                f.seek(0)  # go to start & read protobuf-encoded bytes
                 protobuf_encoded = f.read()
-                try: # if decoded successfully, save & return
-                    _NODE_ID_GLOBAL["keypair"] = Keypair.from_protobuf_encoding(protobuf_encoded)
-                    return _NODE_ID_GLOBAL["keypair"]
-                except RuntimeError as e: # on runtime error, assume corrupt file
+                try:  # if decoded successfully, return
+                    return Keypair.from_protobuf_encoding(protobuf_encoded)
+                except RuntimeError as e:  # on runtime error, assume corrupt file
                     logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
             # if no valid credentials, create new ones and persist
-            with open(EXO_NODE_ID_KEYPAIR, 'w+b') as f:
-                _NODE_ID_GLOBAL["keypair"] = Keypair.generate_ed25519()
-                f.write(_NODE_ID_GLOBAL["keypair"].to_protobuf_encoding())
-                return _NODE_ID_GLOBAL["keypair"]
+            with open(path, 'w+b') as f:
+                keypair = Keypair.generate_ed25519()
+                f.write(keypair.to_protobuf_encoding())
+                return keypair
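
For orientation, here is a minimal self-contained sketch of the race-free read-or-create pattern the reworked get_node_id_keypair follows, with plain bytes standing in for the protobuf-encoded keypair (the helper name and arguments are illustrative, not part of the commit):

import os
from filelock import FileLock

def read_or_create(path: str, fresh: bytes) -> bytes:
    # One lock file next to the target serialises every process touching it.
    with FileLock(str(path) + ".lock"):
        with open(path, "a+b") as f:   # append mode never truncates; position starts at EOF
            if f.tell() != 0:          # non-zero EOF => the file already has contents
                f.seek(0)
                return f.read()
        with open(path, "w+b") as f:   # empty or missing => persist a fresh value
            f.write(fresh)
            return fresh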


@@ -5,6 +5,7 @@ description = "Shared utilities for the Exo project"
 readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
+    "filelock>=3.18.0",
     "aiosqlite>=0.20.0",
     "networkx>=3.5",
     "openai>=1.93.0",


@@ -1,7 +1,11 @@
 from __future__ import annotations
+import contextlib
 import logging
+import multiprocessing
+import os
-from multiprocessing import Event, Process, Queue, Semaphore
+from multiprocessing import Event, Queue, Semaphore
+from multiprocessing.process import BaseProcess
 from multiprocessing.queues import Queue as QueueT
 from multiprocessing.synchronize import Event as EventT
 from multiprocessing.synchronize import Semaphore as SemaphoreT
@@ -14,10 +18,9 @@ from shared.node_id import get_node_id_keypair
 NUM_CONCURRENT_PROCS = 10
-def _get_keypair_concurrent(num_procs: int) -> bytes:
-    assert num_procs > 0
-    def subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+def _get_keypair_concurrent_subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+    try:
         # synchronise with parent process
         logging.info(msg=f"SUBPROCESS {pid}: Started")
         sem.release()
@@ -27,9 +30,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
         logging.info(msg=f"SUBPROCESS {pid}: Reading start")
         queue.put(get_node_id_keypair().to_protobuf_encoding())
         logging.info(msg=f"SUBPROCESS {pid}: Reading end")
+    except Exception as e:
+        logging.error(msg=f"SUBPROCESS {pid}: Error encountered: {e}")
     # notify master of finishing
     sem.release()
+def _get_keypair_concurrent(num_procs: int) -> bytes:
+    assert num_procs > 0
     sem = Semaphore(0)
     ev = Event()
@@ -37,8 +43,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     # make parent process wait for all subprocesses to start
     logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
+    ps: list[BaseProcess] = []
     for i in range(num_procs):
-        Process(target=subprocess_task, args=(i + 1, sem, ev, queue)).start()
+        p = multiprocessing.get_context("fork").Process(target=_get_keypair_concurrent_subprocess_task,
+                                                        args=(i + 1, sem, ev, queue))
+        ps.append(p)
+        p.start()
     for _ in range(num_procs):
         sem.acquire()
@@ -47,26 +57,30 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     ev.set()
     # wait until all subprocesses are done & read results
     for _ in range(num_procs):
         sem.acquire()
+    for p in ps:
+        p.join()
     # check that the input/output order match, and that
     # all subprocesses end up reading the same file
     logging.info(msg="PARENT: Checking consistency")
     keypair: Optional[bytes] = None
-    assert queue.qsize() > 0
-    while queue.qsize() > 0:
+    qsize = 0  # cannot use Queue.qsize due to MacOS incompatibility :(
+    while not queue.empty():
+        qsize += 1
         temp_keypair = queue.get()
         if keypair is None:
             keypair = temp_keypair
         else:
             assert keypair == temp_keypair
-    return keypair  # pyright: ignore[reportReturnType]
+    assert num_procs == qsize
+    return keypair  # pyright: ignore[reportReturnType]
+def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
+    with contextlib.suppress(OSError):
+        os.remove(p)
 def test_node_id_fetching(caplog: LogCaptureFixture):
     reps = 10
@@ -74,7 +88,7 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
-    with caplog.at_level(logging.CRITICAL): # supress logs
+    with caplog.at_level(logging.CRITICAL):  # suppress logs
         # make sure that continuous fetches return the same value
         for _ in range(reps):
             assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
@@ -82,4 +96,4 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     # make sure that after deleting, we are not fetching the same value
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     for _ in range(reps):
-        assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
\ No newline at end of file
+        assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
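
Two portability details in the test above are worth spelling out: multiprocessing.get_context("fork") pins the start method explicitly (macOS has defaulted to "spawn" since Python 3.8, and "spawn" cannot pickle the nested subprocess_task the old code used, hence the move to a module-level function), and Queue.qsize() raises NotImplementedError on macOS, which is why the test now counts items while draining. A minimal sketch of both points, with illustrative names:

import multiprocessing

def _task(q) -> None:
    q.put(b"ok")  # module-level target: importable under "spawn", inherited under "fork"

def run_workers(n: int) -> int:
    ctx = multiprocessing.get_context("fork")  # explicit start method, identical on Linux and macOS
    q = ctx.Queue()
    procs = [ctx.Process(target=_task, args=(q,)) for _ in range(n)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    drained = 0
    while not q.empty():  # drain-and-count instead of q.qsize(), which raises NotImplementedError on macOS
        q.get()
        drained += 1
    return drained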


@@ -1,9 +1,7 @@
-from typing import Any, Type, TypeVar
-T = TypeVar("T")
+from typing import Any, Type
-def ensure_type(obj: Any, expected_type: Type[T]) -> T:  # type: ignore
+def ensure_type[T](obj: Any, expected_type: Type[T]) -> T:  # type: ignore
     if not isinstance(obj, expected_type):
         raise TypeError(f"Expected {expected_type}, got {type(obj)}")  # type: ignore
     return obj
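
The change above replaces the module-level TypeVar with PEP 695 inline generic syntax, available from Python 3.12 (the project already requires >=3.13). A quick usage sketch, assuming ensure_type as defined above; the values are illustrative:

value: object = 42
n = ensure_type(value, int)      # returns 42, statically typed as int
try:
    ensure_type(value, str)      # raises TypeError: Expected <class 'str'>, got <class 'int'>
except TypeError:
    pass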

uv.lock (generated)
View File

@@ -269,6 +269,7 @@ version = "0.1.0"
 source = { editable = "shared" }
 dependencies = [
     { name = "aiosqlite", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -291,6 +292,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "aiosqlite", specifier = ">=0.20.0" },
+    { name = "filelock", specifier = ">=3.18.0" },
     { name = "greenlet", specifier = ">=3.2.3" },
     { name = "networkx", specifier = ">=3.5" },
     { name = "openai", specifier = ">=1.93.0" },