Files
exo/shared/tests/test_node_id_persistence.py
Andrei Cravtov b687dec6b2 Discovery integration master
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
2025-07-27 13:43:59 +01:00

100 lines
3.3 KiB
Python

from __future__ import annotations
import contextlib
import logging
import multiprocessing
import os
from multiprocessing import Event, Queue, Semaphore
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue as QueueT
from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from typing import Optional
from pytest import LogCaptureFixture
from shared.constants import EXO_NODE_ID_KEYPAIR
from shared.utils import get_node_id_keypair
NUM_CONCURRENT_PROCS = 10
def _get_keypair_concurrent_subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
try:
# synchronise with parent process
logging.info(msg=f"SUBPROCESS {pid}: Started")
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
logging.info(msg=f"SUBPROCESS {pid}: Reading start")
queue.put(get_node_id_keypair().to_protobuf_encoding())
logging.info(msg=f"SUBPROCESS {pid}: Reading end")
except Exception as e:
logging.error(msg=f"SUBPROCESS {pid}: Error encountered: {e}")
def _get_keypair_concurrent(num_procs: int) -> bytes:
assert num_procs > 0
sem = Semaphore(0)
ev = Event()
queue: QueueT[bytes] = Queue(maxsize=num_procs)
# make parent process wait for all subprocesses to start
logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
ps: list[BaseProcess] = []
for i in range(num_procs):
p = multiprocessing.get_context("fork").Process(target=_get_keypair_concurrent_subprocess_task,
args=(i + 1, sem, ev, queue))
ps.append(p)
p.start()
for _ in range(num_procs):
sem.acquire()
# start all the sub processes simultaneously
logging.info(msg="PARENT: Beginning read")
ev.set()
# wait until all subprocesses are done & read results
for p in ps:
p.join()
# check that the input/output order match, and that
# all subprocesses end up reading the same file
logging.info(msg="PARENT: Checking consistency")
keypair: Optional[bytes] = None
qsize = 0 # cannot use Queue.qsize due to MacOS incompatibility :(
while not queue.empty():
qsize += 1
temp_keypair = queue.get()
if keypair is None:
keypair = temp_keypair
else:
assert keypair == temp_keypair
assert num_procs == qsize
return keypair # pyright: ignore[reportReturnType]
def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
with contextlib.suppress(OSError):
os.remove(p)
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10
# delete current file and write a new one
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
with caplog.at_level(logging.CRITICAL): # supress logs
# make sure that continuous fetches return the same value
for _ in range(reps):
assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
# make sure that after deleting, we are not fetching the same value
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
for _ in range(reps):
assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)