Mirror of https://github.com/exo-explore/exo.git
Fix the node-ID test
Co-authored-by: Matt Beton <matthew.beton@gmail.com>
shared/node_id.py:

@@ -1,9 +1,11 @@
 from __future__ import annotations

 import logging
-from multiprocessing import Lock
-from multiprocessing.synchronize import Lock as LockT
-from typing import Optional, TypedDict
+import os
+from pathlib import Path

 from exo_pyo3_bindings import Keypair
+from filelock import FileLock

 from shared.constants import EXO_NODE_ID_KEYPAIR
@@ -11,41 +13,32 @@ from shared.constants import EXO_NODE_ID_KEYPAIR
 """
 This file is responsible for concurrent, race-free, persistent node-ID retrieval.
 """

-class _NodeIdGlobal(TypedDict):
-    file_lock: LockT
-    keypair: Optional[Keypair]
-
-_NODE_ID_GLOBAL: _NodeIdGlobal = {
-    "file_lock": Lock(),
-    "keypair": None,
-}
+def _lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
+    return Path(str(path) + ".lock")

-def get_node_id_keypair() -> Keypair:
+def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
     """
     Obtains the :class:`Keypair` associated with this node-ID.
     The :class:`PeerId` can be obtained from it.
     """

-    # get from memory if we have it => read from file otherwise
-    if _NODE_ID_GLOBAL["keypair"] is not None:
-        return _NODE_ID_GLOBAL["keypair"]
-
     # operate with cross-process lock to avoid race conditions
-    with _NODE_ID_GLOBAL["file_lock"]:
-        with open(EXO_NODE_ID_KEYPAIR, 'a+b') as f:  # opens in append mode => starts at EOF
+    with FileLock(_lock_path(path)):
+        with open(path, 'a+b') as f:  # opens in append mode => starts at EOF
             # if non-zero EOF, then file exists => use it to get the node-ID
             if f.tell() != 0:
                 f.seek(0)  # go to start & read protobuf-encoded bytes
                 protobuf_encoded = f.read()

-                try:  # if decoded successfully, save & return
-                    _NODE_ID_GLOBAL["keypair"] = Keypair.from_protobuf_encoding(protobuf_encoded)
-                    return _NODE_ID_GLOBAL["keypair"]
-                except RuntimeError as e:  # on runtime error, assume corrupt file
+                try:  # if decoded successfully, return
+                    return Keypair.from_protobuf_encoding(protobuf_encoded)
+                except RuntimeError as e:  # on runtime error, assume corrupt file
                     logging.warning(f"Encountered runtime error when trying to get keypair: {e}")

         # if no valid credentials, create new ones and persist
-        with open(EXO_NODE_ID_KEYPAIR, 'w+b') as f:
-            _NODE_ID_GLOBAL["keypair"] = Keypair.generate_ed25519()
-            f.write(_NODE_ID_GLOBAL["keypair"].to_protobuf_encoding())
-            return _NODE_ID_GLOBAL["keypair"]
+        with open(path, 'w+b') as f:
+            keypair = Keypair.generate_ed25519()
+            f.write(keypair.to_protobuf_encoding())
+            return keypair
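The diff replaces the in-process multiprocessing.Lock plus module-level cache with a filelock.FileLock held on a sidecar ".lock" file, so unrelated processes (not just children of one parent) serialize access to the same keypair file. Below is a minimal sketch of the read-or-create pattern, with a hypothetical read_or_create helper and a plain bytes payload standing in for the protobuf-encoded Keypair; the corrupt-file retry branch is omitted:

from collections.abc import Callable

from filelock import FileLock


def read_or_create(path: str, make_payload: Callable[[], bytes]) -> bytes:
    # A sidecar lock file guards both the read and the create+write, so two
    # processes can never both observe an empty file and generate twice.
    with FileLock(path + ".lock"):
        with open(path, "a+b") as f:  # 'a+b' creates if missing, cursor at EOF
            if f.tell() != 0:  # non-zero cursor => file is already populated
                f.seek(0)
                return f.read()
        payload = make_payload()  # the first process to get here wins
        with open(path, "w+b") as f:
            f.write(payload)
        return payload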
shared/pyproject.toml:

@@ -5,6 +5,7 @@ description = "Shared utilities for the Exo project"
 readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
+    "filelock>=3.18.0",
     "aiosqlite>=0.20.0",
     "networkx>=3.5",
     "openai>=1.93.0",
Node-ID test:

@@ -1,7 +1,11 @@
 from __future__ import annotations

+import contextlib
 import logging
+import multiprocessing
+import os
-from multiprocessing import Event, Process, Queue, Semaphore
+from multiprocessing import Event, Queue, Semaphore
+from multiprocessing.process import BaseProcess
 from multiprocessing.queues import Queue as QueueT
 from multiprocessing.synchronize import Event as EventT
 from multiprocessing.synchronize import Semaphore as SemaphoreT
@@ -14,10 +18,9 @@ from shared.node_id import get_node_id_keypair

 NUM_CONCURRENT_PROCS = 10

-def _get_keypair_concurrent(num_procs: int) -> bytes:
-    assert num_procs > 0
-
-    def subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+def _get_keypair_concurrent_subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
     try:
         # synchronise with parent process
         logging.info(msg=f"SUBPROCESS {pid}: Started")
         sem.release()
@@ -27,9 +30,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
         logging.info(msg=f"SUBPROCESS {pid}: Reading start")
         queue.put(get_node_id_keypair().to_protobuf_encoding())
         logging.info(msg=f"SUBPROCESS {pid}: Reading end")
     except Exception as e:
         logging.error(msg=f"SUBPROCESS {pid}: Error encountered: {e}")

     # notify master of finishing
     sem.release()

+def _get_keypair_concurrent(num_procs: int) -> bytes:
+    assert num_procs > 0
+
     sem = Semaphore(0)
     ev = Event()
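The worker above and the parent below form a two-phase semaphore barrier: each worker releases the semaphore once at startup, blocks on a shared event, does the contended read, then releases again when done; the parent acquires n times, fires the event, and acquires n more times. A stripped-down sketch of that handshake, under assumed names (_worker, run):

import multiprocessing
from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT


def _worker(sem: SemaphoreT, ev: EventT) -> None:
    sem.release()  # phase 1: tell the parent this worker is alive
    ev.wait()      # block until the parent fires the starting gun
    ...            # the contended work happens here, in every worker at once
    sem.release()  # phase 2: tell the parent this worker is done


def run(n: int) -> None:
    ctx = multiprocessing.get_context("fork")
    sem, ev = ctx.Semaphore(0), ctx.Event()
    procs = [ctx.Process(target=_worker, args=(sem, ev)) for _ in range(n)]
    for p in procs:
        p.start()
    for _ in range(n):
        sem.acquire()  # returns once all n workers have started
    ev.set()           # release every worker at the same instant
    for _ in range(n):
        sem.acquire()  # returns once all n workers have finished
    for p in procs:
        p.join()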
@@ -37,8 +43,12 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:

     # make parent process wait for all subprocesses to start
     logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
+    ps: list[BaseProcess] = []
     for i in range(num_procs):
-        Process(target=subprocess_task, args=(i + 1, sem, ev, queue)).start()
+        p = multiprocessing.get_context("fork").Process(target=_get_keypair_concurrent_subprocess_task,
+                                                        args=(i + 1, sem, ev, queue))
+        ps.append(p)
+        p.start()
     for _ in range(num_procs):
         sem.acquire()
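Two coordinated fixes here: the worker was hoisted to module scope, and the processes now come from an explicit "fork" context. Both plausibly target macOS, where the default start method has been "spawn" since Python 3.8: spawn pickles the target callable, and a function nested inside another function cannot be pickled, whereas a forked child simply inherits the parent's memory. A small illustration with hypothetical names:

import multiprocessing


def module_level() -> None:
    print("picklable: works under both spawn and fork")


def main() -> None:
    def nested() -> None:
        print("only reachable via fork: the child inherits it from memory")

    fork = multiprocessing.get_context("fork")
    spawn = multiprocessing.get_context("spawn")

    p1 = spawn.Process(target=module_level)  # fine: target pickles by name
    p2 = fork.Process(target=nested)         # fine: nothing is pickled
    # spawn.Process(target=nested).start()   # fails: can't pickle local object
    for p in (p1, p2):
        p.start()
    for p in (p1, p2):
        p.join()


if __name__ == "__main__":
    main()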
@@ -47,26 +57,30 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     ev.set()

     # wait until all subprocesses are done & read results
     for _ in range(num_procs):
         sem.acquire()
+    for p in ps:
+        p.join()

     # check that the input/output orders match, and that
     # all subprocesses end up reading the same file
     logging.info(msg="PARENT: Checking consistency")
     keypair: Optional[bytes] = None
-    assert queue.qsize() > 0
-    while queue.qsize() > 0:
+    qsize = 0  # cannot use Queue.qsize due to macOS incompatibility :(
+    while not queue.empty():
+        qsize += 1
         temp_keypair = queue.get()
         if keypair is None:
             keypair = temp_keypair
         else:
             assert keypair == temp_keypair
-    return keypair  # pyright: ignore[reportReturnType]
+    assert num_procs == qsize
+    return keypair  # pyright: ignore[reportReturnType]


+def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
+    with contextlib.suppress(OSError):
+        os.remove(p)


 def test_node_id_fetching(caplog: LogCaptureFixture):
     reps = 10
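The drain-and-count loop exists because multiprocessing.Queue.qsize() is backed by sem_getvalue(), which macOS does not implement, so qsize() raises NotImplementedError there. Counting while draining via empty()/get() is the portable alternative; note it is only race-free once every producer has finished, as they have here. A sketch with a hypothetical drain helper:

from __future__ import annotations

from multiprocessing.queues import Queue as QueueT


def drain(q: QueueT[bytes]) -> list[bytes]:
    # Portable stand-in for qsize(): count by draining. This is only
    # reliable once every producer has finished and been joined;
    # otherwise empty() can flip mid-production.
    items: list[bytes] = []
    while not q.empty():
        items.append(q.get())
    return items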
@@ -74,7 +88,7 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)

     with caplog.at_level(logging.CRITICAL):  # suppress logs
         # make sure that continuous fetches return the same value
         for _ in range(reps):
             assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
@@ -82,4 +96,4 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     # make sure that after deleting, we are not fetching the same value
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     for _ in range(reps):
         assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
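For completeness: the caplog: LogCaptureFixture signature implies pytest, so the fixed test can be exercised on its own. One way, assuming pytest is installed in the environment:

import pytest

# Equivalent to running `pytest -k test_node_id_fetching` from the shell.
pytest.main(["-k", "test_node_id_fetching"])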
Typing helper:

@@ -1,9 +1,7 @@
-from typing import Any, Type, TypeVar
-
-T = TypeVar("T")
+from typing import Any, Type


-def ensure_type(obj: Any, expected_type: Type[T]) -> T:  # type: ignore
+def ensure_type[T](obj: Any, expected_type: Type[T]) -> T:  # type: ignore
     if not isinstance(obj, expected_type):
         raise TypeError(f"Expected {expected_type}, got {type(obj)}")  # type: ignore
     return obj
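The helper is modernised from a module-level TypeVar to PEP 695 inline type parameters (def ensure_type[T](...)), available since Python 3.12 and so comfortably within the project's requires-python = ">=3.13". Call sites are unchanged; a quick self-contained illustration:

from typing import Any


def ensure_type[T](obj: Any, expected_type: type[T]) -> T:
    if not isinstance(obj, expected_type):
        raise TypeError(f"Expected {expected_type}, got {type(obj)}")
    return obj


n = ensure_type(3, int)  # OK: n is inferred as int
try:
    ensure_type("3", int)
except TypeError as e:
    print(e)  # Expected <class 'int'>, got <class 'str'>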
uv.lock (generated, 2 additions):

@@ -269,6 +269,7 @@ version = "0.1.0"
 source = { editable = "shared" }
 dependencies = [
     { name = "aiosqlite", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -291,6 +292,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "aiosqlite", specifier = ">=0.20.0" },
+    { name = "filelock", specifier = ">=3.18.0" },
     { name = "greenlet", specifier = ">=3.2.3" },
     { name = "networkx", specifier = ">=3.5" },
     { name = "openai", specifier = ">=1.93.0" },