wrote a race-condition-free persistent node-ID retrieval function

This commit is contained in:
Andrei Cravtov
2025-07-23 20:18:56 +01:00
committed by GitHub
parent 7a452c3351
commit 3ab5609289
9 changed files with 212 additions and 10 deletions

View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalDependencies">
<plugin id="com.insyncwithfoo.pyright" />
<plugin id="systems.fehn.intellijdirenv" />
</component>
</project>

View File

@@ -17,7 +17,7 @@ lint-check:
uv run ruff check master worker shared engines/*
test:
uv run pytest master worker shared engines/* rust/exo_pyo3_bindings/tests
uv run pytest master worker shared engines/*
check:
uv run basedpyright --project pyproject.toml

View File

@@ -91,6 +91,10 @@ class Keypair:
r"""
TODO: documentation
"""
def to_peer_id(self) -> PeerId:
r"""
TODO: documentation
"""
class Multiaddr:
r"""
@@ -143,6 +147,12 @@ class PeerId:
r"""
TODO: documentation
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
def __repr__(self) -> builtins.str:
r"""
TODO: documentation
"""
def __str__(self) -> builtins.str:
r"""
TODO: documentation
"""

View File

@@ -3,7 +3,7 @@ use libp2p::identity::{ecdsa, Keypair};
use libp2p::PeerId;
use pyo3::prelude::{PyBytesMethods, PyModule, PyModuleMethods};
use pyo3::types::PyBytes;
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
use pyo3::{pyclass, pymethods, Bound, PyObject, PyResult, Python};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// TODO: documentation...
@@ -76,6 +76,34 @@ impl PyKeypair {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// TODO: documentation
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// TODO: documentation...
@@ -113,10 +141,12 @@ impl PyPeerId {
self.0.to_base58()
}
/// TODO: documentation
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
/// TODO: documentation
fn __str__(&self) -> String {
self.to_base58()
}
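For orientation, a minimal sketch of how the new binding surface is meant to be used from Python, assuming exo_pyo3_bindings exposes Keypair and PeerId as declared in the stubs above:

from exo_pyo3_bindings import Keypair

# generate an ed25519 identity and derive its libp2p peer ID
kp = Keypair.generate_ed25519()
peer_id = kp.to_peer_id()
print(repr(peer_id))  # "PeerId(<base58>)" via the new __repr__
print(str(peer_id))   # bare base58 string via the new __str__

# round-trip through the protobuf encoding used for persistence below
restored = Keypair.from_protobuf_encoding(kp.to_protobuf_encoding())
assert str(restored.to_peer_id()) == str(peer_id)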

View File

@@ -1,3 +1,7 @@
import logging
import multiprocessing
import multiprocessing.queues
import pickle
import time
from collections.abc import Awaitable
from typing import Callable
@@ -48,7 +52,7 @@ async def test_discovery_callbacks() -> None:
service.add_connected_callback(add_connected_callback)
service.add_disconnected_callback(disconnected_callback)
for i in range(0, 10):
for i in range(0, 1):
print(f"PYTHON: tick {i} of 10")
time.sleep(1)
@@ -67,6 +71,21 @@ def disconnected_callback(e: ConnectionUpdate) -> None:
f"PYTHON: Disconnected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
async def foobar(a: Callable[[str], Awaitable[str]]):
abc = await a("")
pass
# async def foobar(a: Callable[[str], Awaitable[str]]):
# abc = await a("")
# pass
# def test_keypair_pickling() -> None:
# def subprocess_task(kp: Keypair, q: multiprocessing.queues.Queue[Keypair]):
# logging.info("a")
# assert q.get() == kp
# logging.info("b")
#
#
# kp = Keypair.generate_ed25519()
# q: multiprocessing.queues.Queue[Keypair] = multiprocessing.Queue()
#
# p = multiprocessing.Process(target=subprocess_task, args=(kp, q))
# p.start()
# q.put(kp)
# p.join()

View File

@@ -9,6 +9,8 @@ EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
EXO_MASTER_LOG = EXO_HOME / "master.log"
EXO_WORKER_LOG = EXO_HOME / "worker.log"
EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair"
EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"

shared/node_id.py Normal file
View File

@@ -0,0 +1,51 @@
import logging
from multiprocessing import Lock
from multiprocessing.synchronize import Lock as LockT
from typing import Optional, TypedDict
from exo_pyo3_bindings import Keypair
from shared.constants import EXO_NODE_ID_KEYPAIR
"""
This module provides race-free, persistent node-ID retrieval across concurrent processes.
"""
class _NodeIdGlobal(TypedDict):
file_lock: LockT
keypair: Optional[Keypair]
_NODE_ID_GLOBAL: _NodeIdGlobal = {
"file_lock": Lock(),
"keypair": None,
}
def get_node_id_keypair() -> Keypair:
"""
Obtains the :class:`Keypair` associated with this node-ID.
The :class:`PeerId` can be obtained from it via :meth:`Keypair.to_peer_id`.
"""
# get from memory if we have it => read from file otherwise
if _NODE_ID_GLOBAL["keypair"] is not None:
return _NODE_ID_GLOBAL["keypair"]
# operate with cross-process lock to avoid race conditions
with _NODE_ID_GLOBAL["file_lock"]:
with open(EXO_NODE_ID_KEYPAIR, 'a+b') as f: # opens in append-mode => starts at EOF
# if the EOF offset is non-zero, the file is non-empty => use it to get the node-ID
if f.tell() != 0:
f.seek(0) # go to start & read protobuf-encoded bytes
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
_NODE_ID_GLOBAL["keypair"] = Keypair.from_protobuf_encoding(protobuf_encoded)
return _NODE_ID_GLOBAL["keypair"]
except RuntimeError as e: # on runtime error, assume corrupt file
logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(EXO_NODE_ID_KEYPAIR, 'w+b') as f:
_NODE_ID_GLOBAL["keypair"] = Keypair.generate_ed25519()
f.write(_NODE_ID_GLOBAL["keypair"].to_protobuf_encoding())
return _NODE_ID_GLOBAL["keypair"]
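The existence check here leans on a quirk of append mode: opening with 'a+b' creates the file if it is missing and positions the offset at EOF, so tell() doubles as a size probe with no separate stat call. A standalone sketch of just that probe, using a hypothetical scratch path:

import os

PROBE = "probe.bin"  # hypothetical scratch file for this demo

with open(PROBE, "a+b") as f:  # created if missing; offset starts at EOF
    if f.tell() == 0:
        # empty or freshly created => the branch that would generate a new keypair
        f.write(b"persisted bytes")
    else:
        f.seek(0)  # rewind and read what a previous run persisted
        print("found", len(f.read()), "existing bytes")
os.remove(PROBE)  # clean up the demo file

One design note: a multiprocessing.Lock only coordinates processes that inherit it from a common parent, so two independently launched programs would each hold their own lock and be serialised only by the file itself.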

View File

@@ -15,7 +15,7 @@ dependencies = [
"rustworkx>=0.16.0",
"sqlmodel>=0.0.22",
"sqlalchemy[asyncio]>=2.0.0",
"greenlet>=3.2.3"
"greenlet>=3.2.3",
]
[build-system]
@@ -41,3 +41,8 @@ dev = [
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
]
[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
asyncio_mode = "auto"
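The asyncio_mode = "auto" setting is what lets bare async def tests, such as test_discovery_callbacks above, run without an explicit pytest.mark.asyncio marker. A minimal illustration, in a hypothetical test file collected by pytest:

import asyncio

async def test_tick() -> None:
    # awaited automatically under asyncio_mode = "auto";
    # no @pytest.mark.asyncio decorator required
    assert await asyncio.sleep(0) is None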

View File

@@ -0,0 +1,85 @@
import contextlib
import logging
import os
from multiprocessing import Event, Process, Queue, Semaphore
from multiprocessing.queues import Queue as QueueT
from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from typing import Optional
from pytest import LogCaptureFixture
from shared.constants import EXO_NODE_ID_KEYPAIR
from shared.node_id import get_node_id_keypair
NUM_CONCURRENT_PROCS = 10
def _get_keypair_concurrent(num_procs: int) -> bytes:
assert num_procs > 0
def subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
# synchronise with parent process
logging.info(msg=f"SUBPROCESS {pid}: Started")
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
logging.info(msg=f"SUBPROCESS {pid}: Reading start")
queue.put(get_node_id_keypair().to_protobuf_encoding())
logging.info(msg=f"SUBPROCESS {pid}: Reading end")
# notify the parent process of finishing
sem.release()
sem = Semaphore(0)
ev = Event()
queue: QueueT[bytes] = Queue(maxsize=num_procs)
# make parent process wait for all subprocesses to start
logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
for i in range(num_procs):
Process(target=subprocess_task, args=(i + 1, sem, ev, queue)).start()
for _ in range(num_procs):
sem.acquire()
# release all subprocesses to begin reading simultaneously
logging.info(msg="PARENT: Beginning read")
ev.set()
# wait until all subprocesses are done & read results
for _ in range(num_procs):
sem.acquire()
# check that all subprocesses end up reading the same keypair bytes
logging.info(msg="PARENT: Checking consistency")
keypair: Optional[bytes] = None
assert queue.qsize() > 0
while queue.qsize() > 0:
temp_keypair = queue.get()
if keypair is None:
keypair = temp_keypair
else:
assert keypair == temp_keypair
return keypair # pyright: ignore[reportReturnType]
def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
with contextlib.suppress(OSError):
os.remove(p)
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10
# delete current file and write a new one
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
with caplog.at_level(logging.CRITICAL): # suppress logs
# make sure that continuous fetches return the same value
for _ in range(reps):
assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
# make sure that after deleting, we are not fetching the same value
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
for _ in range(reps):
assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
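The start barrier in subprocess_task is the load-bearing part of this test: each child releases the semaphore once it is alive, the parent acquires it num_procs times, and only then sets the event, so every child falls through ev.wait() at roughly the same instant. A stripped-down sketch of that rendezvous pattern, with illustrative names:

from multiprocessing import Event, Process, Semaphore
from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT

def _worker(ready: SemaphoreT, go: EventT) -> None:
    ready.release()  # tell the parent this worker has started
    go.wait()        # block until the parent opens the gate
    # ...the code under test would run here, simultaneously in every worker...

if __name__ == "__main__":
    n = 4
    ready, go = Semaphore(0), Event()
    procs = [Process(target=_worker, args=(ready, go)) for _ in range(n)]
    for p in procs:
        p.start()
    for _ in range(n):
        ready.acquire()  # one acquire per worker release
    go.set()  # open the gate for all workers at once
    for p in procs:
        p.join()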