mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)
wrote race-condition-free persistent NodeID-getting function
.idea/externalDependencies.xml (2 changed lines, generated)
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
   <component name="ExternalDependencies">
-    <plugin id="com.insyncwithfoo.pyright" />
+    <plugin id="systems.fehn.intellijdirenv" />
   </component>
 </project>
justfile (2 changed lines)
@@ -17,7 +17,7 @@ lint-check:
     uv run ruff check master worker shared engines/*
 
 test:
-    uv run pytest master worker shared engines/* rust/exo_pyo3_bindings/tests
+    uv run pytest master worker shared engines/*
 
 check:
     uv run basedpyright --project pyproject.toml
@@ -91,6 +91,10 @@ class Keypair:
         r"""
         TODO: documentation
         """
+    def to_peer_id(self) -> PeerId:
+        r"""
+        TODO: documentation
+        """
 
 class Multiaddr:
     r"""
@@ -143,6 +147,12 @@ class PeerId:
         r"""
         TODO: documentation
         """
-    def __repr__(self) -> builtins.str: ...
-    def __str__(self) -> builtins.str: ...
+    def __repr__(self) -> builtins.str:
+        r"""
+        TODO: documentation
+        """
+    def __str__(self) -> builtins.str:
+        r"""
+        TODO: documentation
+        """
 
@@ -3,7 +3,7 @@ use libp2p::identity::{ecdsa, Keypair};
 use libp2p::PeerId;
 use pyo3::prelude::{PyBytesMethods, PyModule, PyModuleMethods};
 use pyo3::types::PyBytes;
-use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use pyo3::{pyclass, pymethods, Bound, PyObject, PyResult, Python};
 use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
 
 /// TODO: documentation...
@@ -76,6 +76,34 @@ impl PyKeypair {
         let bytes = self.0.to_protobuf_encoding().pyerr()?;
         Ok(PyBytes::new(py, &bytes))
     }
+
+    /// TODO: documentation
+    fn to_peer_id(&self) -> PyPeerId {
+        PyPeerId(self.0.public().to_peer_id())
+    }
+
+    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
+    // #[gen_stub(skip)]
+    // #[new]
+    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
+    //     Self::from_protobuf_encoding(bytes)
+    // }
+    //
+    // #[gen_stub(skip)]
+    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
+    //     *self = Self::from_protobuf_encoding(state)?;
+    //     Ok(())
+    // }
+    //
+    // #[gen_stub(skip)]
+    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+    //     self.to_protobuf_encoding(py)
+    // }
+    //
+    // #[gen_stub(skip)]
+    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
+    //     Ok((self.to_protobuf_encoding(py)?,))
+    // }
 }
 
 /// TODO: documentation...
@@ -113,10 +141,12 @@ impl PyPeerId {
         self.0.to_base58()
     }
 
+    /// TODO: documentation
     fn __repr__(&self) -> String {
         format!("PeerId({})", self.to_base58())
     }
 
+    /// TODO: documentation
    fn __str__(&self) -> String {
         self.0.to_base58()
     }
@@ -1,3 +1,7 @@
+import logging
+import multiprocessing
+import multiprocessing.queues
+import pickle
 import time
 from collections.abc import Awaitable
 from typing import Callable
@@ -48,7 +52,7 @@ async def test_discovery_callbacks() -> None:
     service.add_connected_callback(add_connected_callback)
     service.add_disconnected_callback(disconnected_callback)
 
-    for i in range(0, 10):
+    for i in range(0, 1):
         print(f"PYTHON: tick {i} of 10")
         time.sleep(1)
 
@@ -67,6 +71,21 @@ def disconnected_callback(e: ConnectionUpdate) -> None:
         f"PYTHON: Disconnected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
 
 
-async def foobar(a: Callable[[str], Awaitable[str]]):
-    abc = await a("")
-    pass
+# async def foobar(a: Callable[[str], Awaitable[str]]):
+#     abc = await a("")
+#     pass
+
+# def test_keypair_pickling() -> None:
+#     def subprocess_task(kp: Keypair, q: multiprocessing.queues.Queue[Keypair]):
+#         logging.info("a")
+#         assert q.get() == kp
+#         logging.info("b")
+#
+#
+#     kp = Keypair.generate_ed25519()
+#     q: multiprocessing.queues.Queue[Keypair] = multiprocessing.Queue()
+#
+#     p = multiprocessing.Process(target=subprocess_task, args=(kp, q))
+#     p.start()
+#     q.put(kp)
+#     p.join()
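
Until the Rust-side pickling hooks sketched above are finished, pickling could be approximated from the Python side alone. Below is a minimal sketch, not part of this commit, assuming only the to_protobuf_encoding/from_protobuf_encoding API that the bindings already expose; the helper names and the copyreg registration are illustrative:

import copyreg

from exo_pyo3_bindings import Keypair


def _unpickle_keypair(encoded: bytes) -> Keypair:
    # rebuild the keypair from its protobuf encoding
    return Keypair.from_protobuf_encoding(encoded)


def _reduce_keypair(kp: Keypair):
    # serialise via the protobuf encoding, mirroring the commented-out
    # __getstate__/__getnewargs__ methods in the Rust bindings
    return _unpickle_keypair, (kp.to_protobuf_encoding(),)


# register a reducer so pickle (and hence multiprocessing queues)
# can round-trip Keypair objects without Rust-side changes
copyreg.pickle(Keypair, _reduce_keypair)

With this registration in place, the commented-out test_keypair_pickling above should be runnable as written, since multiprocessing pickles queue items with the standard pickle machinery.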
@@ -9,6 +9,8 @@ EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
|
||||
EXO_MASTER_LOG = EXO_HOME / "master.log"
|
||||
EXO_WORKER_LOG = EXO_HOME / "worker.log"
|
||||
|
||||
EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair"
|
||||
|
||||
EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
|
||||
EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
|
||||
|
||||
|
||||
shared/node_id.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import logging
+from multiprocessing import Lock
+from multiprocessing.synchronize import Lock as LockT
+from typing import Optional, TypedDict
+
+from exo_pyo3_bindings import Keypair
+
+from shared.constants import EXO_NODE_ID_KEYPAIR
+
+"""
+This file is responsible for concurrent, race-free, persistent node-ID retrieval.
+"""
+
+class _NodeIdGlobal(TypedDict):
+    file_lock: LockT
+    keypair: Optional[Keypair]
+
+_NODE_ID_GLOBAL: _NodeIdGlobal = {
+    "file_lock": Lock(),
+    "keypair": None,
+}
+
+def get_node_id_keypair() -> Keypair:
+    """
+    Obtains the :class:`Keypair` associated with this node-ID.
+    Obtain the :class:`PeerId` from it via :meth:`Keypair.to_peer_id`.
+    """
+
+    # get from memory if we have it => read from file otherwise
+    if _NODE_ID_GLOBAL["keypair"] is not None:
+        return _NODE_ID_GLOBAL["keypair"]
+
+    # operate with cross-process lock to avoid race conditions
+    with _NODE_ID_GLOBAL["file_lock"]:
+        with open(EXO_NODE_ID_KEYPAIR, 'a+b') as f:  # opens in append-mode => starts at EOF
+            # if non-zero EOF, then the file exists => use it to get the node-ID
+            if f.tell() != 0:
+                f.seek(0)  # go to start & read protobuf-encoded bytes
+                protobuf_encoded = f.read()
+
+                try:  # if decoded successfully, save & return
+                    _NODE_ID_GLOBAL["keypair"] = Keypair.from_protobuf_encoding(protobuf_encoded)
+                    return _NODE_ID_GLOBAL["keypair"]
+                except RuntimeError as e:  # on runtime error, assume a corrupt file
+                    logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
+
+        # if no valid credentials, create new ones and persist
+        with open(EXO_NODE_ID_KEYPAIR, 'w+b') as f:
+            _NODE_ID_GLOBAL["keypair"] = Keypair.generate_ed25519()
+            f.write(_NODE_ID_GLOBAL["keypair"].to_protobuf_encoding())
+            return _NODE_ID_GLOBAL["keypair"]
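
For callers, retrieval is a single function call: the first call in a process reads or creates the file under the cross-process lock, and later calls hit the in-memory cache. A minimal usage sketch (get_node_id_keypair as defined above, to_peer_id as added to the bindings in this commit; the startup call site is an assumption):

from shared.node_id import get_node_id_keypair

# fetch the persistent identity, creating and persisting it on first use;
# safe to call concurrently from multiple processes
keypair = get_node_id_keypair()

# derive the libp2p PeerId that other nodes will see;
# per the bindings above, str() yields its base58 encoding
peer_id = keypair.to_peer_id()
print(f"node ID: {peer_id}")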
@@ -15,7 +15,7 @@ dependencies = [
     "rustworkx>=0.16.0",
     "sqlmodel>=0.0.22",
     "sqlalchemy[asyncio]>=2.0.0",
-    "greenlet>=3.2.3"
+    "greenlet>=3.2.3",
 ]
 
 [build-system]
@@ -41,3 +41,8 @@ dev = [
     "pytest>=8.4.0",
     "pytest-asyncio>=1.0.0",
 ]
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+asyncio_mode = "auto"
shared/tests/test_node_id_persistence.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+import contextlib
+import logging
+import os
+from multiprocessing import Event, Process, Queue, Semaphore
+from multiprocessing.queues import Queue as QueueT
+from multiprocessing.synchronize import Event as EventT
+from multiprocessing.synchronize import Semaphore as SemaphoreT
+from typing import Optional
+
+from pytest import LogCaptureFixture
+
+from shared.constants import EXO_NODE_ID_KEYPAIR
+from shared.node_id import get_node_id_keypair
+
+NUM_CONCURRENT_PROCS = 10
+
+def _get_keypair_concurrent(num_procs: int) -> bytes:
+    assert num_procs > 0
+
+    def subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
+        # synchronise with parent process
+        logging.info(msg=f"SUBPROCESS {pid}: Started")
+        sem.release()
+
+        # wait to be told to begin simultaneous read
+        ev.wait()
+        logging.info(msg=f"SUBPROCESS {pid}: Reading start")
+        queue.put(get_node_id_keypair().to_protobuf_encoding())
+        logging.info(msg=f"SUBPROCESS {pid}: Reading end")
+
+        # notify the parent of finishing
+        sem.release()
+
+    sem = Semaphore(0)
+    ev = Event()
+    queue: QueueT[bytes] = Queue(maxsize=num_procs)
+
+    # make the parent process wait for all subprocesses to start
+    logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
+    for i in range(num_procs):
+        Process(target=subprocess_task, args=(i + 1, sem, ev, queue)).start()
+    for _ in range(num_procs):
+        sem.acquire()
+
+    # start all the subprocesses simultaneously
+    logging.info(msg="PARENT: Beginning read")
+    ev.set()
+
+    # wait until all subprocesses are done & read results
+    for _ in range(num_procs):
+        sem.acquire()
+
+    # check that the input/output orders match, and that
+    # all subprocesses end up reading the same file
+    logging.info(msg="PARENT: Checking consistency")
+    keypair: Optional[bytes] = None
+    assert queue.qsize() > 0
+    while queue.qsize() > 0:
+        temp_keypair = queue.get()
+        if keypair is None:
+            keypair = temp_keypair
+        else:
+            assert keypair == temp_keypair
+    return keypair  # pyright: ignore[reportReturnType]
+
+def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
+    with contextlib.suppress(OSError):
+        os.remove(p)
+
+def test_node_id_fetching(caplog: LogCaptureFixture):
+    reps = 10
+
+    # delete the current file and write a new one
+    _delete_if_exists(EXO_NODE_ID_KEYPAIR)
+    kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
+
+    with caplog.at_level(logging.CRITICAL):  # suppress logs
+        # make sure that consecutive fetches return the same value
+        for _ in range(reps):
+            assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
+
+        # make sure that after deleting, we are not fetching the same value
+        _delete_if_exists(EXO_NODE_ID_KEYPAIR)
+        for _ in range(reps):
+            assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
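
One branch the new test does not cover is corrupt-file recovery: get_node_id_keypair catches RuntimeError from Keypair.from_protobuf_encoding, logs a warning, and regenerates credentials. A possible follow-up test, sketched here but not part of this commit, assuming decoding garbage bytes raises exactly that error (it reuses the helpers defined above):

def test_node_id_corrupt_file_recovery():
    # establish a valid persisted keypair first
    kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)

    # overwrite the keypair file with bytes that cannot be decoded
    with open(EXO_NODE_ID_KEYPAIR, 'wb') as f:
        f.write(b"not a protobuf-encoded keypair")

    # the corrupt file should be discarded and fresh credentials persisted
    assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)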