mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Discovery integration master
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
This commit is contained in:
12
.idea/exo-v2.iml
generated
12
.idea/exo-v2.iml
generated
@@ -1,5 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="EMPTY_MODULE" version="4">
|
||||
<component name="FacetManager">
|
||||
<facet type="Python" name="Python facet">
|
||||
<configuration sdkName="Python 3.13 virtualenv at ~/Desktop/exo/.venv" />
|
||||
</facet>
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/discovery/src" isTestSource="false" />
|
||||
@@ -11,10 +16,17 @@
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/util/fn_pipe/proc/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/util/fn_pipe/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/util/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/engines/mlx" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/master" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/shared" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/worker" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/rust/target" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.direnv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.13 (exo)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="Python 3.13 virtualenv at ~/Desktop/exo/.venv interpreter library" level="application" />
|
||||
</component>
|
||||
</module>
|
||||
132
master/discovery_supervisor.py
Normal file
132
master/discovery_supervisor.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, DiscoveryService, Keypair
|
||||
|
||||
from shared.db import AsyncSQLiteEventStorage
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import TopologyEdgeCreated, TopologyEdgeDeleted
|
||||
from shared.types.topology import Connection
|
||||
|
||||
|
||||
class DiscoverySupervisor:
    """Turn discovery-service connection updates into topology events.

    Registers a "connected" and a "disconnected" callback on a
    ``DiscoveryService`` and, for every ``ConnectionUpdate`` delivered to them,
    appends a ``TopologyEdgeCreated`` / ``TopologyEdgeDeleted`` event to
    ``global_events``.

    Must be constructed while an asyncio event loop is running: the constructor
    spawns its consumer tasks with ``asyncio.create_task``.
    """

    def __init__(self, node_id_keypair: Keypair, node_id: NodeId, global_events: AsyncSQLiteEventStorage,
                 logger: logging.Logger):
        """Create the discovery service and start consuming its updates.

        Args:
            node_id_keypair: Keypair handed to ``DiscoveryService``.
            node_id: Id of the local node; used as the local end of every edge
                and as the origin passed to ``append_events``.
            global_events: Event storage that receives the topology events.
            logger: Logger used to report each connection update.
        """
        self.global_events = global_events
        self.logger = logger
        self.node_id = node_id

        # configure callbacks
        self.discovery_service = DiscoveryService(node_id_keypair)
        # Keep strong references to the consumer tasks: the event loop itself
        # only holds weak references to tasks, and holding them here also makes
        # it possible to cancel them later.
        self._connected_task = self._add_connected_callback()
        self._disconnected_task = self._add_disconnected_callback()

    def _add_connected_callback(self) -> asyncio.Task[None]:
        """Register the connected callback and return the task draining it."""
        stream_get, stream_put = _make_iter()
        self.discovery_service.add_connected_callback(stream_put)

        async def run() -> None:
            async for c in stream_get:
                await self._connected_callback(c)

        return asyncio.create_task(run())

    def _add_disconnected_callback(self) -> asyncio.Task[None]:
        """Register the disconnected callback and return the task draining it."""
        stream_get, stream_put = _make_iter()
        self.discovery_service.add_disconnected_callback(stream_put)

        async def run() -> None:
            async for c in stream_get:
                await self._disconnected_callback(c)

        return asyncio.create_task(run())

    async def _connected_callback(self, e: ConnectionUpdate) -> None:
        """Record a ``TopologyEdgeCreated`` event for the update ``e``."""
        await self._publish_edge_event(e, TopologyEdgeCreated, "CONNECTED")

    async def _disconnected_callback(self, e: ConnectionUpdate) -> None:
        """Record a ``TopologyEdgeDeleted`` event for the update ``e``."""
        await self._publish_edge_event(e, TopologyEdgeDeleted, "DISCONNECTED")

    async def _publish_edge_event(
            self,
            e: ConnectionUpdate,
            event_type: type[TopologyEdgeCreated] | type[TopologyEdgeDeleted],
            label: str,
    ) -> None:
        """Build the edge described by ``e``, log it and append the event.

        Args:
            e: Update received from the discovery service.
            event_type: Event class to instantiate with the edge.
            label: Prefix of the log line ("CONNECTED" or "DISCONNECTED").
        """
        local_node_id = self.node_id
        send_back_node_id = NodeId(e.peer_id.to_base58())
        local_multiaddr = e.local_addr.to_string()
        send_back_multiaddr = e.send_back_addr.to_string()

        topology_event = event_type(edge=Connection(
            local_node_id=local_node_id,
            send_back_node_id=send_back_node_id,
            local_multiaddr=local_multiaddr,
            send_back_multiaddr=send_back_multiaddr,
            # No profile is available at discovery time.
            connection_profile=None
        ))
        # NOTE(review): logged at ERROR level although these look like routine
        # events -- confirm the intended level before changing it.
        self.logger.error(
            msg=f"{label} CALLBACK: {local_node_id} -> {send_back_node_id}, {local_multiaddr} -> {send_back_multiaddr}")
        await self.global_events.append_events(
            [topology_event],
            self.node_id
        )
|
||||
|
||||
|
||||
def _make_iter():  # TODO: generalize to generic utility
    """Create a queue-backed bridge from a plain callback to an async iterator.

    Must be called while an event loop is running; that loop is the one the
    queue is fed on.

    Returns:
        A ``(stream, put)`` pair. ``put(item)`` is a synchronous callable that
        schedules ``item`` onto the queue via ``call_soon_threadsafe``, so it
        may be invoked from a thread other than the loop's. ``stream`` is an
        async generator that yields the queued items in order and never
        terminates on its own.

    Raises:
        RuntimeError: If no event loop is running.
    """
    # get_running_loop() instead of the deprecated-without-a-loop
    # get_event_loop(): items must land on the loop that consumes `stream`.
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue[ConnectionUpdate]" = asyncio.Queue()

    def put(c: "ConnectionUpdate") -> None:
        # asyncio.Queue is not thread-safe, so hop onto the loop's thread.
        loop.call_soon_threadsafe(queue.put_nowait, c)

    async def get():
        while True:
            yield await queue.get()

    return get(), put
|
||||
|
||||
# class MyClass: # TODO: figure out how to make pydantic integrate with Multiaddr
|
||||
# def __init__(self, data: str):
|
||||
# self.data = data
|
||||
#
|
||||
# @staticmethod
|
||||
# def from_str(s: str, _i: ValidationInfo) -> 'MyClass':
|
||||
# return MyClass(s)
|
||||
#
|
||||
# def __str__(self):
|
||||
# return self.data
|
||||
#
|
||||
# @classmethod
|
||||
# def __get_pydantic_core_schema__(
|
||||
# cls, source_type: type[any], handler: GetCoreSchemaHandler
|
||||
# ) -> CoreSchema:
|
||||
# return core_schema.with_info_after_validator_function(
|
||||
# function=MyClass.from_str,
|
||||
# schema=core_schema.bytes_schema(),
|
||||
# serialization=core_schema.to_string_ser_schema()
|
||||
# )
|
||||
#
|
||||
#
|
||||
# # Use directly in a model (no Annotated needed)
|
||||
# class ExampleModel(BaseModel):
|
||||
# field: MyClass
|
||||
#
|
||||
#
|
||||
# m = ExampleModel(field=MyClass("foo"))
|
||||
# d = m.model_dump()
|
||||
# djs = m.model_dump_json()
|
||||
#
|
||||
# print(d)
|
||||
# print(djs)
|
||||
@@ -6,7 +6,10 @@ import traceback
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from exo_pyo3_bindings import Keypair
|
||||
|
||||
from master.api import start_fastapi_server
|
||||
from master.discovery_supervisor import DiscoverySupervisor
|
||||
from master.election_callback import ElectionCallbacks
|
||||
from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
|
||||
from master.placement import get_instance_placements, get_transition_events
|
||||
@@ -14,7 +17,6 @@ from shared.apply import apply
|
||||
from shared.db.sqlite.config import EventLogConfig
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from shared.node_id import get_node_id_keypair
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import (
|
||||
Event,
|
||||
@@ -30,14 +32,23 @@ from shared.types.events.commands import (
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.utils import get_node_id_keypair
|
||||
|
||||
|
||||
class Master:
|
||||
def __init__(self, node_id: NodeId, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, worker_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: logging.Logger):
|
||||
def __init__(self, node_id_keypair: Keypair, node_id: NodeId, command_buffer: list[Command],
|
||||
global_events: AsyncSQLiteEventStorage, worker_events: AsyncSQLiteEventStorage,
|
||||
forwarder_binary_path: Path, logger: logging.Logger):
|
||||
self.node_id = node_id
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.worker_events = worker_events
|
||||
self.discovery_supervisor = DiscoverySupervisor(
|
||||
node_id_keypair,
|
||||
node_id,
|
||||
global_events,
|
||||
logger
|
||||
)
|
||||
self.forwarder_supervisor = ForwarderSupervisor(
|
||||
forwarder_binary_path=forwarder_binary_path,
|
||||
logger=logger
|
||||
@@ -128,7 +139,6 @@ class Master:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
logger = logging.getLogger('master_logger')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@@ -163,8 +173,10 @@ async def main():
|
||||
api_thread.start()
|
||||
logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')
|
||||
|
||||
master = Master(node_id, command_buffer, global_events, worker_events, forwarder_binary_path=Path("./build/forwarder"), logger=logger)
|
||||
master = Master(node_id_keypair, node_id, command_buffer, global_events, worker_events,
|
||||
forwarder_binary_path=Path("./build/forwarder"), logger=logger)
|
||||
await master.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -15,17 +15,17 @@ def create_node():
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return Node(
|
||||
node_id=node_id,
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000
|
||||
),
|
||||
network_interfaces=[],
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(flops_fp16=1000)
|
||||
)
|
||||
)
|
||||
@@ -37,10 +37,11 @@ def create_node():
|
||||
def create_connection():
|
||||
def _create_connection(source_node_id: NodeId, sink_node_id: NodeId) -> Connection:
|
||||
return Connection(
|
||||
source_node_id=source_node_id,
|
||||
sink_node_id=sink_node_id,
|
||||
source_multiaddr="/ip4/127.0.0.1/tcp/1234",
|
||||
sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
local_multiaddr="/ip4/127.0.0.1/tcp/1234",
|
||||
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
|
||||
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
)
|
||||
return _create_connection
|
||||
|
||||
return _create_connection
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from exo_pyo3_bindings import Keypair
|
||||
|
||||
from master.main import Master
|
||||
from shared.db.sqlite.config import EventLogConfig
|
||||
@@ -38,8 +39,10 @@ async def test_master():
|
||||
|
||||
forwarder_binary_path = _create_forwarder_dummy_binary()
|
||||
|
||||
node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
master = Master(node_id, command_buffer=command_buffer, global_events=global_events, worker_events=event_log_manager.worker_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
|
||||
node_id_keypair = Keypair.generate_ed25519()
|
||||
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
|
||||
master = Master(node_id_keypair, node_id, command_buffer=command_buffer, global_events=global_events,
|
||||
forwarder_binary_path=forwarder_binary_path, logger=logger, worker_events=event_log_manager.worker_events)
|
||||
asyncio.create_task(master.run())
|
||||
|
||||
command_buffer.append(
|
||||
|
||||
@@ -13,20 +13,27 @@ from shared.types.topology import Connection, ConnectionProfile, Node, NodeId
|
||||
def topology() -> Topology:
|
||||
return Topology()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection() -> Connection:
|
||||
return Connection(source_node_id=NodeId(), sink_node_id=NodeId(), source_multiaddr="/ip4/127.0.0.1/tcp/1234", sink_multiaddr="/ip4/127.0.0.1/tcp/1235", connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000))
|
||||
return Connection(local_node_id=NodeId(), send_back_node_id=NodeId(), local_multiaddr="/ip4/127.0.0.1/tcp/1234",
|
||||
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
|
||||
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_profile() -> NodePerformanceProfile:
|
||||
memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000)
|
||||
system_profile = SystemPerformanceProfile(flops_fp16=1000)
|
||||
return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[], system=system_profile)
|
||||
return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[],
|
||||
system=system_profile)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_profile() -> ConnectionProfile:
|
||||
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
|
||||
|
||||
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
||||
# arrange
|
||||
node_id = NodeId()
|
||||
@@ -41,39 +48,47 @@ def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
||||
|
||||
def test_add_connection(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
data = topology.get_connection_profile(connection)
|
||||
|
||||
# assert
|
||||
assert data == connection.connection_profile
|
||||
assert data == connection.connection_profile
|
||||
|
||||
|
||||
def test_update_node_profile(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
|
||||
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test",
|
||||
memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000,
|
||||
swap_total=1000, swap_available=1000),
|
||||
network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
|
||||
|
||||
# act
|
||||
topology.update_node_profile(connection.source_node_id, node_profile=new_node_profile)
|
||||
topology.update_node_profile(connection.local_node_id, node_profile=new_node_profile)
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(connection.source_node_id)
|
||||
data = topology.get_node_profile(connection.local_node_id)
|
||||
assert data == new_node_profile
|
||||
|
||||
|
||||
def test_update_connection_profile(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_connection_profile = ConnectionProfile(throughput=2000, latency=2000, jitter=2000)
|
||||
connection = Connection(source_node_id=connection.source_node_id, sink_node_id=connection.sink_node_id, source_multiaddr=connection.source_multiaddr, sink_multiaddr=connection.sink_multiaddr, connection_profile=new_connection_profile)
|
||||
connection = Connection(local_node_id=connection.local_node_id, send_back_node_id=connection.send_back_node_id,
|
||||
local_multiaddr=connection.local_multiaddr,
|
||||
send_back_multiaddr=connection.send_back_multiaddr,
|
||||
connection_profile=new_connection_profile)
|
||||
|
||||
# act
|
||||
topology.update_connection_profile(connection)
|
||||
@@ -82,10 +97,12 @@ def test_update_connection_profile(topology: Topology, node_profile: NodePerform
|
||||
data = topology.get_connection_profile(connection)
|
||||
assert data == new_connection_profile
|
||||
|
||||
def test_remove_connection_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
|
||||
def test_remove_connection_still_connected(topology: Topology, node_profile: NodePerformanceProfile,
|
||||
connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
@@ -94,7 +111,8 @@ def test_remove_connection_still_connected(topology: Topology, node_profile: Nod
|
||||
# assert
|
||||
with pytest.raises(IndexError):
|
||||
topology.get_connection_profile(connection)
|
||||
|
||||
|
||||
|
||||
def test_remove_connection_bridge(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
"""Create a bridge scenario: master -> node_a -> node_b
|
||||
and remove the bridge connection (master -> node_a)"""
|
||||
@@ -102,63 +120,63 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma
|
||||
master_id = NodeId()
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
|
||||
|
||||
topology.add_node(Node(node_id=master_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=node_a_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=node_b_id, node_profile=node_profile))
|
||||
|
||||
|
||||
connection_master_to_a = Connection(
|
||||
source_node_id=master_id,
|
||||
sink_node_id=node_a_id,
|
||||
source_multiaddr="/ip4/127.0.0.1/tcp/1234",
|
||||
sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
|
||||
local_node_id=master_id,
|
||||
send_back_node_id=node_a_id,
|
||||
local_multiaddr="/ip4/127.0.0.1/tcp/1234",
|
||||
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
|
||||
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
)
|
||||
|
||||
|
||||
connection_a_to_b = Connection(
|
||||
source_node_id=node_a_id,
|
||||
sink_node_id=node_b_id,
|
||||
source_multiaddr="/ip4/127.0.0.1/tcp/1236",
|
||||
sink_multiaddr="/ip4/127.0.0.1/tcp/1237",
|
||||
local_node_id=node_a_id,
|
||||
send_back_node_id=node_b_id,
|
||||
local_multiaddr="/ip4/127.0.0.1/tcp/1236",
|
||||
send_back_multiaddr="/ip4/127.0.0.1/tcp/1237",
|
||||
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
)
|
||||
|
||||
|
||||
topology.add_connection(connection_master_to_a)
|
||||
topology.add_connection(connection_a_to_b)
|
||||
|
||||
|
||||
assert len(list(topology.list_nodes())) == 3
|
||||
|
||||
|
||||
topology.remove_connection(connection_master_to_a)
|
||||
|
||||
|
||||
remaining_nodes = list(topology.list_nodes())
|
||||
assert len(remaining_nodes) == 1
|
||||
assert remaining_nodes[0].node_id == master_id
|
||||
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
topology.get_node_profile(node_a_id)
|
||||
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
topology.get_node_profile(node_b_id)
|
||||
|
||||
|
||||
def test_remove_node_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
topology.remove_node(connection.source_node_id)
|
||||
topology.remove_node(connection.local_node_id)
|
||||
|
||||
# assert
|
||||
with pytest.raises(KeyError):
|
||||
topology.get_node_profile(connection.source_node_id)
|
||||
topology.get_node_profile(connection.local_node_id)
|
||||
|
||||
|
||||
def test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
@@ -167,4 +185,4 @@ def test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, co
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, Node) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {connection.source_node_id, connection.sink_node_id}
|
||||
assert {node.node_id for node in nodes} == {connection.local_node_id, connection.send_back_node_id}
|
||||
|
||||
@@ -76,7 +76,7 @@ libp2p-tcp = "0.44"
|
||||
# interop
|
||||
pyo3 = "0.25"
|
||||
#pyo3-stub-gen = { git = "https://github.com/Jij-Inc/pyo3-stub-gen.git", rev = "d2626600e52452e71095c57e721514de748d419d" } # v0.11 not yet published to crates
|
||||
pyo3-stub-gen = { git = "https://github.com/cstruct/pyo3-stub-gen.git", rev = "2efddde7dcffc462868aa0e4bbc46877c657a0fe" } # This fork adds support for type overrides => not merged yet!!!
|
||||
pyo3-stub-gen = { git = "https://github.com/cstruct/pyo3-stub-gen.git", rev = "a935099276fa2d273496a2759d4af7177a6acd57" } # This fork adds support for type overrides => not merged yet!!!
|
||||
pyo3-async-runtimes = "0.25"
|
||||
|
||||
[workspace.lints.rust]
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
use crate::alias::AnyResult;
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity, mdns};
|
||||
use libp2p::core::Endpoint;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::swarm::derive_prelude::Either;
|
||||
use libp2p::swarm::{
|
||||
ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm,
|
||||
NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, gossipsub, identity, mdns};
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -12,8 +20,183 @@ pub struct DiscoveryBehaviour {
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
// #[doc = "`NetworkBehaviour::ToSwarm` produced by DiscoveryBehaviour."]
|
||||
// pub enum DiscoveryBehaviourEvent {
|
||||
// Mdns(<mdns::tokio::Behaviour as NetworkBehaviour>::ToSwarm),
|
||||
// Gossipsub(<gossipsub::Behaviour as NetworkBehaviour>::ToSwarm),
|
||||
// }
|
||||
// impl Debug for DiscoveryBehaviourEvent
|
||||
// where
|
||||
// <mdns::tokio::Behaviour as NetworkBehaviour>::ToSwarm: Debug,
|
||||
// <gossipsub::Behaviour as NetworkBehaviour>::ToSwarm: Debug,
|
||||
// {
|
||||
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
// match &self {
|
||||
// DiscoveryBehaviourEvent::Mdns(event) => {
|
||||
// f.write_fmt(format_args!("{}: {:?}", "DiscoveryBehaviourEvent", event))
|
||||
// }
|
||||
// DiscoveryBehaviourEvent::Gossipsub(event) => {
|
||||
// f.write_fmt(format_args!("{}: {:?}", "DiscoveryBehaviourEvent", event))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// impl NetworkBehaviour for DiscoveryBehaviour
|
||||
// where
|
||||
// mdns::tokio::Behaviour: NetworkBehaviour,
|
||||
// gossipsub::Behaviour: NetworkBehaviour,
|
||||
// {
|
||||
// type ConnectionHandler =
|
||||
// ConnectionHandlerSelect<THandler<mdns::tokio::Behaviour>, THandler<gossipsub::Behaviour>>;
|
||||
// type ToSwarm = DiscoveryBehaviourEvent;
|
||||
// #[allow(clippy::needless_question_mark)]
|
||||
// fn handle_pending_inbound_connection(
|
||||
// &mut self,
|
||||
// connection_id: ConnectionId,
|
||||
// local_addr: &Multiaddr,
|
||||
// remote_addr: &Multiaddr,
|
||||
// ) -> Result<(), ConnectionDenied> {
|
||||
// NetworkBehaviour::handle_pending_inbound_connection(
|
||||
// &mut self.mdns,
|
||||
// connection_id,
|
||||
// local_addr,
|
||||
// remote_addr,
|
||||
// )?;
|
||||
// NetworkBehaviour::handle_pending_inbound_connection(
|
||||
// &mut self.gossipsub,
|
||||
// connection_id,
|
||||
// local_addr,
|
||||
// remote_addr,
|
||||
// )?;
|
||||
// Ok(())
|
||||
// }
|
||||
// #[allow(clippy::needless_question_mark)]
|
||||
// fn handle_established_inbound_connection(
|
||||
// &mut self,
|
||||
// connection_id: ConnectionId,
|
||||
// peer: PeerId,
|
||||
// local_addr: &Multiaddr,
|
||||
// remote_addr: &Multiaddr,
|
||||
// ) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
// Ok(ConnectionHandler::select(
|
||||
// self.mdns.handle_established_inbound_connection(
|
||||
// connection_id,
|
||||
// peer,
|
||||
// local_addr,
|
||||
// remote_addr,
|
||||
// )?,
|
||||
// self.gossipsub.handle_established_inbound_connection(
|
||||
// connection_id,
|
||||
// peer,
|
||||
// local_addr,
|
||||
// remote_addr,
|
||||
// )?,
|
||||
// ))
|
||||
// }
|
||||
// #[allow(clippy::needless_question_mark)]
|
||||
// fn handle_pending_outbound_connection(
|
||||
// &mut self,
|
||||
// connection_id: ConnectionId,
|
||||
// maybe_peer: Option<PeerId>,
|
||||
// addresses: &[Multiaddr],
|
||||
// effective_role: Endpoint,
|
||||
// ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
|
||||
// let mut combined_addresses = Vec::new();
|
||||
// combined_addresses.extend(NetworkBehaviour::handle_pending_outbound_connection(
|
||||
// &mut self.mdns,
|
||||
// connection_id,
|
||||
// maybe_peer,
|
||||
// addresses,
|
||||
// effective_role,
|
||||
// )?);
|
||||
// combined_addresses.extend(NetworkBehaviour::handle_pending_outbound_connection(
|
||||
// &mut self.gossipsub,
|
||||
// connection_id,
|
||||
// maybe_peer,
|
||||
// addresses,
|
||||
// effective_role,
|
||||
// )?);
|
||||
// Ok(combined_addresses)
|
||||
// }
|
||||
// #[allow(clippy::needless_question_mark)]
|
||||
// fn handle_established_outbound_connection(
|
||||
// &mut self,
|
||||
// connection_id: ConnectionId,
|
||||
// peer: PeerId,
|
||||
// addr: &Multiaddr,
|
||||
// role_override: Endpoint,
|
||||
// port_use: PortUse,
|
||||
// ) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
// Ok(ConnectionHandler::select(
|
||||
// self.mdns.handle_established_outbound_connection(
|
||||
// connection_id,
|
||||
// peer,
|
||||
// addr,
|
||||
// role_override,
|
||||
// port_use,
|
||||
// )?,
|
||||
// self.gossipsub.handle_established_outbound_connection(
|
||||
// connection_id,
|
||||
// peer,
|
||||
// addr,
|
||||
// role_override,
|
||||
// port_use,
|
||||
// )?,
|
||||
// ))
|
||||
// }
|
||||
// fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
// self.mdns.on_swarm_event(event);
|
||||
// self.gossipsub.on_swarm_event(event);
|
||||
// }
|
||||
// fn on_connection_handler_event(
|
||||
// &mut self,
|
||||
// peer_id: PeerId,
|
||||
// connection_id: ConnectionId,
|
||||
// event: THandlerOutEvent<Self>,
|
||||
// ) {
|
||||
// match event {
|
||||
// Either::Left(ev) => NetworkBehaviour::on_connection_handler_event(
|
||||
// &mut self.mdns,
|
||||
// peer_id,
|
||||
// connection_id,
|
||||
// ev,
|
||||
// ),
|
||||
// Either::Right(ev) => NetworkBehaviour::on_connection_handler_event(
|
||||
// &mut self.gossipsub,
|
||||
// peer_id,
|
||||
// connection_id,
|
||||
// ev,
|
||||
// ),
|
||||
// }
|
||||
// }
|
||||
// fn poll(
|
||||
// &mut self,
|
||||
// cx: &mut std::task::Context,
|
||||
// ) -> std::task::Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// match NetworkBehaviour::poll(&mut self.mdns, cx) {
|
||||
// std::task::Poll::Ready(e) => {
|
||||
// return std::task::Poll::Ready(
|
||||
// e.map_out(DiscoveryBehaviourEvent::Mdns)
|
||||
// .map_in(|event| Either::Left(event)),
|
||||
// );
|
||||
// }
|
||||
// std::task::Poll::Pending => {}
|
||||
// }
|
||||
// match NetworkBehaviour::poll(&mut self.gossipsub, cx) {
|
||||
// std::task::Poll::Ready(e) => {
|
||||
// return std::task::Poll::Ready(
|
||||
// e.map_out(DiscoveryBehaviourEvent::Gossipsub)
|
||||
// .map_in(|event| Either::Right(event)),
|
||||
// );
|
||||
// }
|
||||
// std::task::Poll::Pending => {}
|
||||
// }
|
||||
// std::task::Poll::Pending
|
||||
// }
|
||||
// }
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> AnyResult<mdns::tokio::Behaviour> {
|
||||
use mdns::{tokio, Config};
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
|
||||
@@ -110,6 +110,16 @@ class Multiaddr:
|
||||
r"""
|
||||
TODO: documentation
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes:bytes) -> Multiaddr:
|
||||
r"""
|
||||
TODO: documentation
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string:builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
TODO: documentation
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
TODO: documentation
|
||||
@@ -122,8 +132,10 @@ class Multiaddr:
|
||||
r"""
|
||||
TODO: documentation
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
TODO: documentation
|
||||
"""
|
||||
|
||||
class PeerId:
|
||||
r"""
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::ext::ResultExt;
|
||||
use crate::pylibp2p::connection::PyConnectionId;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use crate::pylibp2p::multiaddr::PyMultiaddr;
|
||||
use crate::{alias, pyclass, MPSC_CHANNEL_SIZE};
|
||||
use crate::{MPSC_CHANNEL_SIZE, alias, pyclass};
|
||||
use discovery::behaviour::{DiscoveryBehaviour, DiscoveryBehaviourEvent};
|
||||
use discovery::discovery_swarm;
|
||||
use libp2p::core::ConnectedPoint;
|
||||
@@ -17,9 +17,9 @@ use libp2p::futures::StreamExt;
|
||||
use libp2p::multiaddr::multiaddr;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{ConnectionId, SwarmEvent, ToSwarm};
|
||||
use libp2p::{gossipsub, mdns, Multiaddr, PeerId, Swarm};
|
||||
use libp2p::{Multiaddr, PeerId, Swarm, gossipsub, mdns};
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{pymethods, Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python};
|
||||
use pyo3::{Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::convert::identity;
|
||||
use std::error::Error;
|
||||
@@ -274,7 +274,10 @@ impl PyDiscoveryService {
|
||||
#[allow(clippy::expect_used)]
|
||||
fn add_connected_callback<'py>(
|
||||
&self,
|
||||
#[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))]
|
||||
#[gen_stub(override_type(
|
||||
type_repr="collections.abc.Callable[[ConnectionUpdate], None]",
|
||||
imports=("collections.abc")
|
||||
))]
|
||||
callback: PyObject,
|
||||
) -> PyResult<()> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
@@ -304,7 +307,10 @@ impl PyDiscoveryService {
|
||||
#[allow(clippy::expect_used)]
|
||||
fn add_disconnected_callback<'py>(
|
||||
&self,
|
||||
#[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))]
|
||||
#[gen_stub(override_type(
|
||||
type_repr="collections.abc.Callable[[ConnectionUpdate], None]",
|
||||
imports=("collections.abc")
|
||||
))]
|
||||
callback: PyObject,
|
||||
) -> PyResult<()> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use crate::ext::ResultExt;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods};
|
||||
use pyo3::prelude::{PyBytesMethods, PyModule, PyModuleMethods};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr;
|
||||
|
||||
/// TODO: documentation...
|
||||
#[gen_stub_pyclass]
|
||||
@@ -27,6 +29,19 @@ impl PyMultiaddr {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// TODO: documentation
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// TODO: documentation
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// TODO: documentation
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
@@ -43,12 +58,19 @@ impl PyMultiaddr {
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// TODO: documentation
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.0.to_string()
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import multiprocessing.queues
|
||||
import pickle
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from exo_pyo3_bindings import ConnectionUpdate, Keypair, DiscoveryService
|
||||
@@ -49,43 +44,86 @@ async def test_discovery_callbacks() -> None:
|
||||
ident = Keypair.generate_ed25519()
|
||||
|
||||
service = DiscoveryService(ident)
|
||||
service.add_connected_callback(add_connected_callback)
|
||||
service.add_disconnected_callback(disconnected_callback)
|
||||
a = _add_connected_callback(service)
|
||||
d = _add_disconnected_callback(service)
|
||||
|
||||
for i in range(0, 1):
|
||||
# stream_get_a, stream_put = _make_iter()
|
||||
# service.add_connected_callback(stream_put)
|
||||
#
|
||||
# stream_get_d, stream_put = _make_iter()
|
||||
# service.add_disconnected_callback(stream_put)
|
||||
|
||||
# async for c in stream_get_a:
|
||||
# await connected_callback(c)
|
||||
|
||||
for i in range(0, 10):
|
||||
print(f"PYTHON: tick {i} of 10")
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
pass
|
||||
print(service, a, d) # only done to prevent GC... TODO: come up with less hacky solution
|
||||
|
||||
|
||||
def add_connected_callback(e: ConnectionUpdate) -> None:
|
||||
def _add_connected_callback(d: DiscoveryService):
|
||||
stream_get, stream_put = _make_iter()
|
||||
d.add_connected_callback(stream_put)
|
||||
|
||||
async def run():
|
||||
async for c in stream_get:
|
||||
await connected_callback(c)
|
||||
|
||||
return asyncio.create_task(run())
|
||||
|
||||
|
||||
def _add_disconnected_callback(d: DiscoveryService):
|
||||
stream_get, stream_put = _make_iter()
|
||||
|
||||
async def run():
|
||||
async for c in stream_get:
|
||||
await disconnected_callback(c)
|
||||
|
||||
d.add_disconnected_callback(stream_put)
|
||||
return asyncio.create_task(run())
|
||||
|
||||
|
||||
async def connected_callback(e: ConnectionUpdate) -> None:
|
||||
print(f"\n\nPYTHON: Connected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}")
|
||||
print(
|
||||
f"PYTHON: Connected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
|
||||
|
||||
|
||||
def disconnected_callback(e: ConnectionUpdate) -> None:
|
||||
async def disconnected_callback(e: ConnectionUpdate) -> None:
|
||||
print(f"\n\nPYTHON: Disconnected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}")
|
||||
print(
|
||||
f"PYTHON: Disconnected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
|
||||
|
||||
|
||||
# async def foobar(a: Callable[[str], Awaitable[str]]):
|
||||
# abc = await a("")
|
||||
# pass
|
||||
def _foo_task() -> None:
|
||||
print("PYTHON: This simply runs in asyncio context")
|
||||
|
||||
# def test_keypair_pickling() -> None:
|
||||
# def subprocess_task(kp: Keypair, q: multiprocessing.queues.Queue[Keypair]):
|
||||
# logging.info("a")
|
||||
# assert q.get() == kp
|
||||
# logging.info("b")
|
||||
|
||||
def _make_iter():
|
||||
loop = asyncio.get_event_loop()
|
||||
queue: asyncio.Queue[ConnectionUpdate] = asyncio.Queue()
|
||||
|
||||
def put(c: ConnectionUpdate) -> None:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, c)
|
||||
|
||||
async def get():
|
||||
while True:
|
||||
yield await queue.get()
|
||||
|
||||
return get(), put
|
||||
|
||||
# async def inputstream_generator(channels=1, **kwargs):
|
||||
# """Generator that yields blocks of input data as NumPy arrays."""
|
||||
# q_in = asyncio.Queue()
|
||||
# loop = asyncio.get_event_loop()
|
||||
#
|
||||
# def callback(indata, frame_count, time_info, status):
|
||||
# loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status))
|
||||
#
|
||||
# kp = Keypair.generate_ed25519()
|
||||
# q: multiprocessing.queues.Queue[Keypair] = multiprocessing.Queue()
|
||||
#
|
||||
# p = multiprocessing.Process(target=subprocess_task, args=(kp, q))
|
||||
# p.start()
|
||||
# q.put(kp)
|
||||
# p.join()
|
||||
# stream = sd.InputStream(callback=callback, channels=channels, **kwargs)
|
||||
# with stream:
|
||||
# while True:
|
||||
# indata, status = await q_in.get()
|
||||
# yield indata, status
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from exo_pyo3_bindings import Keypair
|
||||
from filelock import FileLock
|
||||
|
||||
from shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
|
||||
"""
|
||||
This file is responsible for concurrent race-free persistent node-ID retrieval.
|
||||
"""
|
||||
|
||||
|
||||
def _lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
|
||||
def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
|
||||
"""
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
|
||||
# operate with cross-process lock to avoid race conditions
|
||||
with FileLock(_lock_path(path)):
|
||||
with open(path, 'a+b') as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
except RuntimeError as e: # on runtime error, assume corrupt file
|
||||
logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, 'w+b') as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
return keypair
|
||||
@@ -14,7 +14,7 @@ from typing import Optional
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from shared.node_id import get_node_id_keypair
|
||||
from shared.utils import get_node_id_keypair
|
||||
|
||||
NUM_CONCURRENT_PROCS = 10
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ def test_state_serialization_roundtrip() -> None:
|
||||
node_b = NodeId("node-b")
|
||||
|
||||
connection = Connection(
|
||||
source_node_id=node_a,
|
||||
sink_node_id=node_b,
|
||||
source_multiaddr="/ip4/127.0.0.1/tcp/10000",
|
||||
sink_multiaddr="/ip4/127.0.0.1/tcp/10001",
|
||||
local_node_id=node_a,
|
||||
send_back_node_id=node_b,
|
||||
local_multiaddr="/ip4/127.0.0.1/tcp/10000",
|
||||
send_back_multiaddr="/ip4/127.0.0.1/tcp/10001",
|
||||
)
|
||||
|
||||
state = State()
|
||||
@@ -27,4 +27,4 @@ def test_state_serialization_roundtrip() -> None:
|
||||
restored_state = State.model_validate_json(json_repr)
|
||||
|
||||
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
|
||||
assert restored_state.model_dump_json() == json_repr
|
||||
assert restored_state.model_dump_json() == json_repr
|
||||
|
||||
@@ -63,18 +63,17 @@ class Topology(TopologyProto):
|
||||
self._node_id_to_rx_id_map[node.node_id] = rx_id
|
||||
self._rx_id_to_node_id_map[rx_id] = node.node_id
|
||||
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
connection: Connection,
|
||||
self,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
if connection.source_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(Node(node_id=connection.source_node_id))
|
||||
if connection.sink_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(Node(node_id=connection.sink_node_id))
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(Node(node_id=connection.local_node_id))
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(Node(node_id=connection.send_back_node_id))
|
||||
|
||||
src_id = self._node_id_to_rx_id_map[connection.source_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.sink_node_id]
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
@@ -89,15 +88,15 @@ class Topology(TopologyProto):
|
||||
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
return self._graph.get_node_data(rx_idx).node_profile
|
||||
|
||||
|
||||
def update_node_profile(self, node_id: NodeId, node_profile: NodePerformanceProfile) -> None:
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph[rx_idx].node_profile = node_profile
|
||||
|
||||
|
||||
def update_connection_profile(self, connection: Connection) -> None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.update_edge_by_index(rx_idx, connection)
|
||||
|
||||
|
||||
def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
|
||||
@@ -112,7 +111,7 @@ class Topology(TopologyProto):
|
||||
def remove_connection(self, connection: Connection) -> None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
if self._is_bridge(connection):
|
||||
orphan_node_ids = self._get_orphan_node_ids(connection.source_node_id, connection)
|
||||
orphan_node_ids = self._get_orphan_node_ids(connection.local_node_id, connection)
|
||||
for orphan_node_id in orphan_node_ids:
|
||||
orphan_node_rx_id = self._node_id_to_rx_id_map[orphan_node_id]
|
||||
self._graph.remove_node(orphan_node_rx_id)
|
||||
@@ -122,16 +121,16 @@ class Topology(TopologyProto):
|
||||
self._graph.remove_edge_from_index(rx_idx)
|
||||
del self._edge_id_to_rx_id_map[connection]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
|
||||
|
||||
def get_cycles(self) -> list[list[Node]]:
|
||||
cycle_idxs = rx.simple_cycles(self._graph)
|
||||
cycles: list[list[Node]] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = [self._graph[idx] for idx in cycle_idx]
|
||||
cycles.append(cycle)
|
||||
|
||||
|
||||
return cycles
|
||||
|
||||
|
||||
def _is_bridge(self, connection: Connection) -> bool:
|
||||
edge_idx = self._edge_id_to_rx_id_map[connection]
|
||||
graph_copy = self._graph.copy().to_undirected()
|
||||
@@ -141,17 +140,17 @@ class Topology(TopologyProto):
|
||||
components_after = rx.number_connected_components(graph_copy)
|
||||
|
||||
return components_after > components_before
|
||||
|
||||
|
||||
def _get_orphan_node_ids(self, master_node_id: NodeId, connection: Connection) -> list[NodeId]:
|
||||
edge_idx = self._edge_id_to_rx_id_map[connection]
|
||||
graph_copy = self._graph.copy().to_undirected()
|
||||
graph_copy.remove_edge_from_index(edge_idx)
|
||||
components = rx.connected_components(graph_copy)
|
||||
|
||||
orphan_node_rx_ids: set[int] = set()
|
||||
|
||||
orphan_node_rx_ids: set[int] = set()
|
||||
master_node_rx_id = self._node_id_to_rx_id_map[master_node_id]
|
||||
for component in components:
|
||||
if master_node_rx_id not in component:
|
||||
orphan_node_rx_ids.update(component)
|
||||
|
||||
|
||||
return [self._rx_id_to_node_id_map[rx_id] for rx_id in orphan_node_rx_ids]
|
||||
|
||||
@@ -11,15 +11,17 @@ class ID(str):
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
handler: GetCoreSchemaHandler
|
||||
cls,
|
||||
_source: type[Any],
|
||||
handler: GetCoreSchemaHandler
|
||||
) -> core_schema.CoreSchema:
|
||||
# Re‑use the already‑defined schema for `str`
|
||||
return handler.generate_schema(str)
|
||||
|
||||
|
||||
class NodeId(ID):
|
||||
pass
|
||||
|
||||
|
||||
class CommandId(ID):
|
||||
pass
|
||||
|
||||
@@ -177,6 +177,10 @@ class TopologyEdgeCreated(_BaseEvent[_EventType.TopologyEdgeCreated]):
|
||||
|
||||
|
||||
class TopologyEdgeReplacedAtomically(_BaseEvent[_EventType.TopologyEdgeReplacedAtomically]):
|
||||
"""
|
||||
TODO: delete this????
|
||||
"""
|
||||
|
||||
event_type: Literal[_EventType.TopologyEdgeReplacedAtomically] = _EventType.TopologyEdgeReplacedAtomically
|
||||
edge: Connection
|
||||
edge_profile: ConnectionProfile
|
||||
@@ -186,6 +190,7 @@ class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]):
|
||||
event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted
|
||||
edge: Connection
|
||||
|
||||
|
||||
_Event = Union[
|
||||
TaskCreated,
|
||||
TaskStateUpdated,
|
||||
@@ -263,8 +268,6 @@ def _check_event_type_consistency():
|
||||
|
||||
_check_event_type_consistency()
|
||||
|
||||
|
||||
|
||||
Event = Annotated[_Event, Field(discriminator="event_type")]
|
||||
"""Type of events, a discriminated union."""
|
||||
|
||||
@@ -276,4 +279,4 @@ Event = Annotated[_Event, Field(discriminator="event_type")]
|
||||
#
|
||||
# class TimerFired(_BaseEvent[_EventType.TimerFired]):
|
||||
# event_type: Literal[_EventType.TimerFired] = _EventType.TimerFired
|
||||
# timer_id: TimerId
|
||||
# timer_id: TimerId
|
||||
|
||||
@@ -7,31 +7,33 @@ from shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
|
||||
|
||||
class Connection(BaseModel):
|
||||
source_node_id: NodeId
|
||||
sink_node_id: NodeId
|
||||
source_multiaddr: str
|
||||
sink_multiaddr: str
|
||||
local_node_id: NodeId
|
||||
send_back_node_id: NodeId
|
||||
local_multiaddr: str
|
||||
send_back_multiaddr: str
|
||||
connection_profile: ConnectionProfile | None = None
|
||||
|
||||
# required for Connection to be used as a key
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.source_node_id,
|
||||
self.sink_node_id,
|
||||
self.source_multiaddr,
|
||||
self.sink_multiaddr,
|
||||
)
|
||||
return hash(
|
||||
(
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.local_multiaddr,
|
||||
self.send_back_multiaddr,
|
||||
)
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Connection):
|
||||
raise ValueError("Cannot compare Connection with non-Connection")
|
||||
return (
|
||||
self.source_node_id == other.source_node_id
|
||||
and self.sink_node_id == other.sink_node_id
|
||||
and self.source_multiaddr == other.source_multiaddr
|
||||
and self.sink_multiaddr == other.sink_multiaddr
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.local_multiaddr == other.local_multiaddr
|
||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
||||
)
|
||||
|
||||
|
||||
@@ -44,8 +46,8 @@ class TopologyProto(Protocol):
|
||||
def add_node(self, node: Node) -> None: ...
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
connection: Connection,
|
||||
self,
|
||||
connection: Connection,
|
||||
) -> None: ...
|
||||
|
||||
def list_nodes(self) -> Iterable[Node]: ...
|
||||
|
||||
@@ -1,7 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Type
|
||||
|
||||
from exo_pyo3_bindings import Keypair
|
||||
from filelock import FileLock
|
||||
|
||||
from shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
|
||||
|
||||
def ensure_type[T](obj: Any, expected_type: Type[T]) -> T: # type: ignore
|
||||
if not isinstance(obj, expected_type):
|
||||
raise TypeError(f"Expected {expected_type}, got {type(obj)}") # type: ignore
|
||||
return obj
|
||||
|
||||
|
||||
# def make_async_iter[T]():
|
||||
# """
|
||||
# Creates a pair `<async-iter>, <put-to-iter>` of an asynchronous iterator
|
||||
# and a synchronous function to put items into that iterator.
|
||||
# """
|
||||
#
|
||||
# loop = asyncio.get_event_loop()
|
||||
# queue: asyncio.Queue[T] = asyncio.Queue()
|
||||
#
|
||||
# def put(c: ConnectionUpdate) -> None:
|
||||
# loop.call_soon_threadsafe(queue.put_nowait, (c,))
|
||||
#
|
||||
# async def get():
|
||||
# while True:
|
||||
# yield await queue.get()
|
||||
#
|
||||
# return get(), put
|
||||
|
||||
def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
|
||||
"""
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` from it.
|
||||
"""
|
||||
|
||||
def lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
# operate with cross-process lock to avoid race conditions
|
||||
with FileLock(lock_path(path)):
|
||||
with open(path, 'a+b') as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
except RuntimeError as e: # on runtime error, assume corrupt file
|
||||
logging.warning(f"Encountered runtime error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, 'w+b') as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
return keypair
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel, ConfigDict
|
||||
from shared.apply import apply
|
||||
from shared.db.sqlite import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.node_id import get_node_id_keypair
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import (
|
||||
ChunkGenerated,
|
||||
@@ -54,6 +53,7 @@ from shared.types.worker.runners import (
|
||||
RunningRunnerStatus,
|
||||
)
|
||||
from shared.types.worker.shards import ShardMetadata
|
||||
from shared.utils import get_node_id_keypair
|
||||
from worker.download.download_utils import build_model_path
|
||||
from worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from worker.utils.profile import start_polling_node_metrics
|
||||
@@ -226,7 +226,6 @@ class Worker:
|
||||
assigned_runner.status = ReadyRunnerStatus()
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
|
||||
async def _execute_task_op(
|
||||
self, op: ExecuteTaskOp
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
@@ -308,7 +307,6 @@ class Worker:
|
||||
# Ensure the task is cleaned up
|
||||
await task
|
||||
|
||||
|
||||
## Operation Planner
|
||||
|
||||
async def _execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
|
||||
@@ -474,7 +472,7 @@ class Worker:
|
||||
running_runner_count = 0
|
||||
for other_runner_id, other_runner_status in state.runners.items():
|
||||
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
|
||||
isinstance(other_runner_status, RunningRunnerStatus):
|
||||
isinstance(other_runner_status, RunningRunnerStatus):
|
||||
running_runner_count += 1
|
||||
|
||||
if running_runner_count == runner.shard_metadata.world_size - 1:
|
||||
|
||||
Reference in New Issue
Block a user