mirror of
https://github.com/exo-explore/exo.git
synced 2026-06-24 05:48:57 -04:00
100 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import logging
|
|
import multiprocessing
|
|
import os
|
|
from multiprocessing import Event, Queue, Semaphore
|
|
from multiprocessing.process import BaseProcess
|
|
from multiprocessing.queues import Queue as QueueT
|
|
from multiprocessing.synchronize import Event as EventT
|
|
from multiprocessing.synchronize import Semaphore as SemaphoreT
|
|
from typing import Optional
|
|
|
|
from pytest import LogCaptureFixture
|
|
|
|
from shared.constants import EXO_NODE_ID_KEYPAIR
|
|
from shared.utils import get_node_id_keypair
|
|
|
|
# Number of subprocesses that concurrently fetch the node-ID keypair per round.
NUM_CONCURRENT_PROCS: int = 10
|
|
|
|
|
|
def _get_keypair_concurrent_subprocess_task(pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]) -> None:
    """Subprocess body: check in, wait for the start signal, then read the keypair.

    Args:
        pid: 1-based index of this subprocess; used only in log messages.
        sem: Released once to tell the parent that this subprocess has started.
        ev: Waited on so that all subprocesses begin reading together.
        queue: Receives the protobuf encoding of the keypair that was read.

    Any exception is logged with its traceback and then re-raised, so the
    subprocess exits with a non-zero exit code. Otherwise it would exit
    "successfully" without having produced a result.
    """
    try:
        # synchronise with parent process
        logging.info(msg=f"SUBPROCESS {pid}: Started")
        sem.release()

        # wait to be told to begin simultaneous read
        ev.wait()
        logging.info(msg=f"SUBPROCESS {pid}: Reading start")
        queue.put(get_node_id_keypair().to_protobuf_encoding())
        logging.info(msg=f"SUBPROCESS {pid}: Reading end")
    except Exception as e:
        # Keep the log line, but also record the traceback and propagate the
        # failure so it is visible through the process exit code.
        logging.exception(msg=f"SUBPROCESS {pid}: Error encountered: {e}")
        raise
|
|
|
|
|
|
def _get_keypair_concurrent(num_procs: int, timeout: float = 30.0) -> bytes:
    """Fetch the node-ID keypair from ``num_procs`` forked subprocesses at once.

    All subprocesses are started first and must check in through a shared
    semaphore. Only then is a shared event set, so that they all call
    ``get_node_id_keypair`` at (roughly) the same moment. Each subprocess puts
    the protobuf encoding of the keypair it read on a shared queue.

    Args:
        num_procs: Number of concurrent subprocesses; must be positive.
        timeout: Seconds to wait for each subprocess to check in, for each
            result to arrive, and for each subprocess to exit. A crashed or
            stuck subprocess becomes a test failure instead of a hang.

    Returns:
        The protobuf-encoded keypair, which every subprocess must agree on.

    Raises:
        AssertionError: If ``num_procs`` is not positive, a subprocess does not
            check in / report / exit within ``timeout``, a subprocess exits
            with a non-zero exit code, or the subprocesses return different
            keypairs.
    """
    from queue import Empty  # raised by multiprocessing Queue.get on timeout

    assert num_procs > 0

    # Create the processes AND every primitive they share from the same
    # "fork" context, instead of mixing a fork Process with default-context
    # synchronisation objects.
    ctx = multiprocessing.get_context("fork")
    sem = ctx.Semaphore(0)
    ev = ctx.Event()
    result_queue: QueueT[bytes] = ctx.Queue(maxsize=num_procs)

    logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
    ps: list[BaseProcess] = []
    results: list[bytes] = []
    try:
        for i in range(num_procs):
            p = ctx.Process(
                target=_get_keypair_concurrent_subprocess_task,
                args=(i + 1, sem, ev, result_queue),
            )
            ps.append(p)
            p.start()

        # make parent process wait for all subprocesses to start
        # (not an `assert`: the acquire must still run under `python -O`)
        for _ in range(num_procs):
            if not sem.acquire(timeout=timeout):
                raise AssertionError("PARENT: timed out waiting for subprocesses to start")

        # start all the subprocesses simultaneously
        logging.info(msg="PARENT: Beginning read")
        ev.set()

        # Read every result BEFORE joining: a process that has put items on a
        # queue may not terminate until those items have been consumed.
        for _ in range(num_procs):
            try:
                results.append(result_queue.get(timeout=timeout))
            except Empty:
                raise AssertionError(
                    f"PARENT: received {len(results)} of {num_procs} results before timing out"
                ) from None

        # wait until all subprocesses are done and make sure none of them failed
        for p in ps:
            p.join(timeout)
            assert p.exitcode == 0, f"PARENT: subprocess exited with code {p.exitcode}"
    finally:
        # Never leave stragglers behind (e.g. subprocesses still blocked on
        # `ev` after a failed start-up): kill and reap them.
        for p in ps:
            if p.is_alive():
                p.terminate()
                p.join()

    # check that all subprocesses ended up reading the same keypair
    logging.info(msg="PARENT: Checking consistency")
    keypair = results[0]
    for other in results[1:]:
        assert keypair == other
    return keypair
|
|
|
|
|
|
def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
|
|
with contextlib.suppress(OSError):
|
|
os.remove(p)
|
|
|
|
|
|
def test_node_id_fetching(caplog: LogCaptureFixture) -> None:
    """Check that the node-ID keypair is stable until its file is deleted.

    Phase 1: delete the keypair file. A concurrent fetch then yields ``kp``,
    and repeated concurrent fetches must all return ``kp``.
    Phase 2: delete the file again. The next fetch must differ from ``kp``,
    and repeated concurrent fetches must all return that new keypair.

    NOTE(review): this deletes whatever file ``EXO_NODE_ID_KEYPAIR`` points to
    and presumably leaves a freshly generated keypair there afterwards. If
    that path can be a developer's real node identity, consider redirecting
    it to a temporary location for the test — TODO confirm.
    """
    reps = 10

    # delete current file and write a new one
    _delete_if_exists(EXO_NODE_ID_KEYPAIR)
    kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)

    with caplog.at_level(logging.CRITICAL):  # suppress logs
        # make sure that continuous fetches return the same value
        for _ in range(reps):
            assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)

        # make sure that after deleting, we are not fetching the same value,
        # and that the regenerated keypair is itself fetched consistently
        _delete_if_exists(EXO_NODE_ID_KEYPAIR)
        new_kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
        assert new_kp != kp
        for _ in range(reps):
            assert new_kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
|