Discovery integration master

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
This commit is contained in:
Andrei Cravtov
2025-07-27 15:43:59 +03:00
committed by GitHub
parent 98f204d14a
commit b687dec6b2
21 changed files with 655 additions and 199 deletions

12
.idea/exo-v2.iml generated
View File

@@ -1,5 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="FacetManager">
<facet type="Python" name="Python facet">
<configuration sdkName="Python 3.13 virtualenv at ~/Desktop/exo/.venv" />
</facet>
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/rust/discovery/src" isTestSource="false" />
@@ -11,10 +16,17 @@
<sourceFolder url="file://$MODULE_DIR$/rust/util/fn_pipe/proc/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/util/fn_pipe/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/util/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/engines/mlx" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/master" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/shared" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/worker" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/rust/target" />
<excludeFolder url="file://$MODULE_DIR$/.direnv" />
<excludeFolder url="file://$MODULE_DIR$/build" />
</content>
<orderEntry type="jdk" jdkName="Python 3.13 (exo)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Python 3.13 virtualenv at ~/Desktop/exo/.venv interpreter library" level="application" />
</component>
</module>

View File

@@ -0,0 +1,132 @@
import asyncio
import logging
from exo_pyo3_bindings import ConnectionUpdate, DiscoveryService, Keypair
from shared.db import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import TopologyEdgeCreated, TopologyEdgeDeleted
from shared.types.topology import Connection
class DiscoverySupervisor:
    """Forward peer connect/disconnect notifications into the global event log.

    A ``DiscoveryService`` is created from the node's keypair and two consumer
    tasks are started: one turns every "connected" update into a
    ``TopologyEdgeCreated`` event, the other turns every "disconnected" update
    into a ``TopologyEdgeDeleted`` event. Each event is appended to
    ``global_events`` together with this node's id.

    Must be instantiated while an asyncio event loop is running, because the
    consumer tasks are spawned with ``asyncio.create_task``.
    """

    def __init__(self, node_id_keypair: Keypair, node_id: NodeId, global_events: AsyncSQLiteEventStorage,
                 logger: logging.Logger):
        self.global_events = global_events
        self.logger = logger
        self.node_id = node_id

        # configure callbacks
        self.discovery_service = DiscoveryService(node_id_keypair)
        # Hold strong references to the consumer tasks: the asyncio docs state
        # that the event loop only keeps weak references to tasks, so a task
        # whose handle is dropped may be garbage-collected mid-execution.
        self._connected_task = self._add_connected_callback()
        self._disconnected_task = self._add_disconnected_callback()

    def _add_connected_callback(self):
        """Register the "connected" hook and return the task that drains it."""
        stream_get, stream_put = _make_iter()
        self.discovery_service.add_connected_callback(stream_put)

        async def run():
            async for c in stream_get:
                await self._connected_callback(c)

        return asyncio.create_task(run())

    def _add_disconnected_callback(self):
        """Register the "disconnected" hook and return the task that drains it."""
        stream_get, stream_put = _make_iter()
        self.discovery_service.add_disconnected_callback(stream_put)

        async def run():
            async for c in stream_get:
                await self._disconnected_callback(c)

        return asyncio.create_task(run())

    def _edge_and_summary(self, e: ConnectionUpdate):
        """Build the ``Connection`` described by *e* plus a log-friendly summary.

        Returns a ``(connection, summary)`` tuple. The summary is formatted from
        the raw values taken from the update (not from the constructed
        ``Connection``), so the logged text does not depend on any processing the
        ``Connection`` type may apply to its fields.
        """
        local_node_id = self.node_id
        send_back_node_id = NodeId(e.peer_id.to_base58())
        local_multiaddr = e.local_addr.to_string()
        send_back_multiaddr = e.send_back_addr.to_string()
        edge = Connection(
            local_node_id=local_node_id,
            send_back_node_id=send_back_node_id,
            local_multiaddr=local_multiaddr,
            send_back_multiaddr=send_back_multiaddr,
            # The profile is left unset here; nothing in this class measures it.
            connection_profile=None
        )
        summary = f"{local_node_id} -> {send_back_node_id}, {local_multiaddr} -> {send_back_multiaddr}"
        return edge, summary

    async def _connected_callback(self, e: ConnectionUpdate) -> None:
        """Record a newly established connection as a ``TopologyEdgeCreated`` event."""
        edge, summary = self._edge_and_summary(e)
        topology_edge_created = TopologyEdgeCreated(edge=edge)
        # NOTE(review): logged at ERROR level although this is a routine event;
        # looks like a debugging leftover - confirm before lowering the level.
        self.logger.error(msg=f"CONNECTED CALLBACK: {summary}")
        await self.global_events.append_events(
            [topology_edge_created],
            self.node_id
        )

    async def _disconnected_callback(self, e: ConnectionUpdate) -> None:
        """Record a dropped connection as a ``TopologyEdgeDeleted`` event."""
        edge, summary = self._edge_and_summary(e)
        topology_edge_deleted = TopologyEdgeDeleted(edge=edge)
        # NOTE(review): logged at ERROR level although this is a routine event;
        # looks like a debugging leftover - confirm before lowering the level.
        self.logger.error(msg=f"DISCONNECTED CALLBACK: {summary}")
        await self.global_events.append_events(
            [topology_edge_deleted],
            self.node_id
        )
def _make_iter():  # TODO: generalize to generic utility
    """Create a thread-safe bridge from a plain callback to an async iterator.

    Returns a ``(stream, put)`` pair:

    * ``put(c)`` may be called from any thread; it schedules ``c`` to be
      enqueued on the event loop that was running when this function was called.
    * ``stream`` is an endless async generator yielding the enqueued items in
      the order they were scheduled.

    Raises:
        RuntimeError: if called while no event loop is running. Binding to the
            running loop (instead of ``asyncio.get_event_loop()``) guarantees
            that ``put`` targets the loop that will consume ``stream``.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[ConnectionUpdate] = asyncio.Queue()

    def put(c: "ConnectionUpdate") -> None:
        # ``asyncio.Queue`` is not thread-safe, so hop onto the loop thread first.
        loop.call_soon_threadsafe(queue.put_nowait, c)

    async def get():
        while True:
            yield await queue.get()

    return get(), put
# class MyClass: # TODO: figure out how to make pydantic integrate with Multiaddr
# def __init__(self, data: str):
# self.data = data
#
# @staticmethod
# def from_str(s: str, _i: ValidationInfo) -> 'MyClass':
# return MyClass(s)
#
# def __str__(self):
# return self.data
#
# @classmethod
# def __get_pydantic_core_schema__(
# cls, source_type: type[any], handler: GetCoreSchemaHandler
# ) -> CoreSchema:
# return core_schema.with_info_after_validator_function(
# function=MyClass.from_str,
# schema=core_schema.bytes_schema(),
# serialization=core_schema.to_string_ser_schema()
# )
#
#
# # Use directly in a model (no Annotated needed)
# class ExampleModel(BaseModel):
# field: MyClass
#
#
# m = ExampleModel(field=MyClass("foo"))
# d = m.model_dump()
# djs = m.model_dump_json()
#
# print(d)
# print(djs)

View File

@@ -6,7 +6,10 @@ import traceback
from pathlib import Path
from typing import List
from exo_pyo3_bindings import Keypair
from master.api import start_fastapi_server
from master.discovery_supervisor import DiscoverySupervisor
from master.election_callback import ElectionCallbacks
from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
from master.placement import get_instance_placements, get_transition_events
@@ -14,7 +17,6 @@ from shared.apply import apply
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.node_id import get_node_id_keypair
from shared.types.common import NodeId
from shared.types.events import (
Event,
@@ -30,14 +32,23 @@ from shared.types.events.commands import (
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from shared.types.worker.instances import Instance
from shared.utils import get_node_id_keypair
class Master:
def __init__(self, node_id: NodeId, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, worker_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: logging.Logger):
def __init__(self, node_id_keypair: Keypair, node_id: NodeId, command_buffer: list[Command],
global_events: AsyncSQLiteEventStorage, worker_events: AsyncSQLiteEventStorage,
forwarder_binary_path: Path, logger: logging.Logger):
self.node_id = node_id
self.command_buffer = command_buffer
self.global_events = global_events
self.worker_events = worker_events
self.discovery_supervisor = DiscoverySupervisor(
node_id_keypair,
node_id,
global_events,
logger
)
self.forwarder_supervisor = ForwarderSupervisor(
forwarder_binary_path=forwarder_binary_path,
logger=logger
@@ -128,7 +139,6 @@ class Master:
await asyncio.sleep(0.1)
async def main():
logger = logging.getLogger('master_logger')
logger.setLevel(logging.DEBUG)
@@ -163,8 +173,10 @@ async def main():
api_thread.start()
logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')
master = Master(node_id, command_buffer, global_events, worker_events, forwarder_binary_path=Path("./build/forwarder"), logger=logger)
master = Master(node_id_keypair, node_id, command_buffer, global_events, worker_events,
forwarder_binary_path=Path("./build/forwarder"), logger=logger)
await master.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -15,17 +15,17 @@ def create_node():
if node_id is None:
node_id = NodeId()
return Node(
node_id=node_id,
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
model_id="test",
chip_id="test",
memory=MemoryPerformanceProfile(
ram_total=1000,
ram_available=memory,
swap_total=1000,
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000
),
network_interfaces=[],
),
network_interfaces=[],
system=SystemPerformanceProfile(flops_fp16=1000)
)
)
@@ -37,10 +37,11 @@ def create_node():
def create_connection():
def _create_connection(source_node_id: NodeId, sink_node_id: NodeId) -> Connection:
return Connection(
source_node_id=source_node_id,
sink_node_id=sink_node_id,
source_multiaddr="/ip4/127.0.0.1/tcp/1234",
sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
local_multiaddr="/ip4/127.0.0.1/tcp/1234",
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
)
return _create_connection
return _create_connection

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from typing import List
import pytest
from exo_pyo3_bindings import Keypair
from master.main import Master
from shared.db.sqlite.config import EventLogConfig
@@ -38,8 +39,10 @@ async def test_master():
forwarder_binary_path = _create_forwarder_dummy_binary()
node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
master = Master(node_id, command_buffer=command_buffer, global_events=global_events, worker_events=event_log_manager.worker_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
node_id_keypair = Keypair.generate_ed25519()
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
master = Master(node_id_keypair, node_id, command_buffer=command_buffer, global_events=global_events,
forwarder_binary_path=forwarder_binary_path, logger=logger, worker_events=event_log_manager.worker_events)
asyncio.create_task(master.run())
command_buffer.append(

View File

@@ -13,20 +13,27 @@ from shared.types.topology import Connection, ConnectionProfile, Node, NodeId
def topology() -> Topology:
return Topology()
@pytest.fixture
def connection() -> Connection:
return Connection(source_node_id=NodeId(), sink_node_id=NodeId(), source_multiaddr="/ip4/127.0.0.1/tcp/1234", sink_multiaddr="/ip4/127.0.0.1/tcp/1235", connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000))
return Connection(local_node_id=NodeId(), send_back_node_id=NodeId(), local_multiaddr="/ip4/127.0.0.1/tcp/1234",
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000))
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000)
system_profile = SystemPerformanceProfile(flops_fp16=1000)
return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[], system=system_profile)
return NodePerformanceProfile(model_id="test", chip_id="test", memory=memory_profile, network_interfaces=[],
system=system_profile)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
@@ -41,39 +48,47 @@ def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_connection(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
# act
data = topology.get_connection_profile(connection)
# assert
assert data == connection.connection_profile
assert data == connection.connection_profile
def test_update_node_profile(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test", memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000), network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
new_node_profile = NodePerformanceProfile(model_id="test", chip_id="test",
memory=MemoryPerformanceProfile(ram_total=1000, ram_available=1000,
swap_total=1000, swap_available=1000),
network_interfaces=[], system=SystemPerformanceProfile(flops_fp16=1000))
# act
topology.update_node_profile(connection.source_node_id, node_profile=new_node_profile)
topology.update_node_profile(connection.local_node_id, node_profile=new_node_profile)
# assert
data = topology.get_node_profile(connection.source_node_id)
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(throughput=2000, latency=2000, jitter=2000)
connection = Connection(source_node_id=connection.source_node_id, sink_node_id=connection.sink_node_id, source_multiaddr=connection.source_multiaddr, sink_multiaddr=connection.sink_multiaddr, connection_profile=new_connection_profile)
connection = Connection(local_node_id=connection.local_node_id, send_back_node_id=connection.send_back_node_id,
local_multiaddr=connection.local_multiaddr,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile)
# act
topology.update_connection_profile(connection)
@@ -82,10 +97,12 @@ def test_update_connection_profile(topology: Topology, node_profile: NodePerform
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
def test_remove_connection_still_connected(topology: Topology, node_profile: NodePerformanceProfile,
connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
# act
@@ -94,7 +111,8 @@ def test_remove_connection_still_connected(topology: Topology, node_profile: Nod
# assert
with pytest.raises(IndexError):
topology.get_connection_profile(connection)
def test_remove_connection_bridge(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
"""Create a bridge scenario: master -> node_a -> node_b
and remove the bridge connection (master -> node_a)"""
@@ -102,63 +120,63 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma
master_id = NodeId()
node_a_id = NodeId()
node_b_id = NodeId()
topology.add_node(Node(node_id=master_id, node_profile=node_profile))
topology.add_node(Node(node_id=node_a_id, node_profile=node_profile))
topology.add_node(Node(node_id=node_b_id, node_profile=node_profile))
connection_master_to_a = Connection(
source_node_id=master_id,
sink_node_id=node_a_id,
source_multiaddr="/ip4/127.0.0.1/tcp/1234",
sink_multiaddr="/ip4/127.0.0.1/tcp/1235",
local_node_id=master_id,
send_back_node_id=node_a_id,
local_multiaddr="/ip4/127.0.0.1/tcp/1234",
send_back_multiaddr="/ip4/127.0.0.1/tcp/1235",
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
)
connection_a_to_b = Connection(
source_node_id=node_a_id,
sink_node_id=node_b_id,
source_multiaddr="/ip4/127.0.0.1/tcp/1236",
sink_multiaddr="/ip4/127.0.0.1/tcp/1237",
local_node_id=node_a_id,
send_back_node_id=node_b_id,
local_multiaddr="/ip4/127.0.0.1/tcp/1236",
send_back_multiaddr="/ip4/127.0.0.1/tcp/1237",
connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
)
topology.add_connection(connection_master_to_a)
topology.add_connection(connection_a_to_b)
assert len(list(topology.list_nodes())) == 3
topology.remove_connection(connection_master_to_a)
remaining_nodes = list(topology.list_nodes())
assert len(remaining_nodes) == 1
assert remaining_nodes[0].node_id == master_id
with pytest.raises(KeyError):
topology.get_node_profile(node_a_id)
with pytest.raises(KeyError):
topology.get_node_profile(node_b_id)
def test_remove_node_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
# act
topology.remove_node(connection.source_node_id)
topology.remove_node(connection.local_node_id)
# assert
with pytest.raises(KeyError):
topology.get_node_profile(connection.source_node_id)
topology.get_node_profile(connection.local_node_id)
def test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
# arrange
topology.add_node(Node(node_id=connection.source_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.sink_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(Node(node_id=connection.send_back_node_id, node_profile=node_profile))
topology.add_connection(connection)
# act
@@ -167,4 +185,4 @@ def test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, co
# assert
assert len(nodes) == 2
assert all(isinstance(node, Node) for node in nodes)
assert {node.node_id for node in nodes} == {connection.source_node_id, connection.sink_node_id}
assert {node.node_id for node in nodes} == {connection.local_node_id, connection.send_back_node_id}

View File

@@ -76,7 +76,7 @@ libp2p-tcp = "0.44"
# interop
pyo3 = "0.25"
#pyo3-stub-gen = { git = "https://github.com/Jij-Inc/pyo3-stub-gen.git", rev = "d2626600e52452e71095c57e721514de748d419d" } # v0.11 not yet published to crates
pyo3-stub-gen = { git = "https://github.com/cstruct/pyo3-stub-gen.git", rev = "2efddde7dcffc462868aa0e4bbc46877c657a0fe" } # This fork adds support for type overrides => not merged yet!!!
pyo3-stub-gen = { git = "https://github.com/cstruct/pyo3-stub-gen.git", rev = "a935099276fa2d273496a2759d4af7177a6acd57" } # This fork adds support for type overrides => not merged yet!!!
pyo3-async-runtimes = "0.25"
[workspace.lints.rust]

View File

@@ -1,6 +1,14 @@
use crate::alias::AnyResult;
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity, mdns};
use libp2p::core::Endpoint;
use libp2p::core::transport::PortUse;
use libp2p::swarm::derive_prelude::Either;
use libp2p::swarm::{
ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm,
NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
};
use libp2p::{Multiaddr, PeerId, gossipsub, identity, mdns};
use std::fmt;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::time::Duration;
@@ -12,8 +20,183 @@ pub struct DiscoveryBehaviour {
pub gossipsub: gossipsub::Behaviour,
}
// #[doc = "`NetworkBehaviour::ToSwarm` produced by DiscoveryBehaviour."]
// pub enum DiscoveryBehaviourEvent {
// Mdns(<mdns::tokio::Behaviour as NetworkBehaviour>::ToSwarm),
// Gossipsub(<gossipsub::Behaviour as NetworkBehaviour>::ToSwarm),
// }
// impl Debug for DiscoveryBehaviourEvent
// where
// <mdns::tokio::Behaviour as NetworkBehaviour>::ToSwarm: Debug,
// <gossipsub::Behaviour as NetworkBehaviour>::ToSwarm: Debug,
// {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
// match &self {
// DiscoveryBehaviourEvent::Mdns(event) => {
// f.write_fmt(format_args!("{}: {:?}", "DiscoveryBehaviourEvent", event))
// }
// DiscoveryBehaviourEvent::Gossipsub(event) => {
// f.write_fmt(format_args!("{}: {:?}", "DiscoveryBehaviourEvent", event))
// }
// }
// }
// }
// impl NetworkBehaviour for DiscoveryBehaviour
// where
// mdns::tokio::Behaviour: NetworkBehaviour,
// gossipsub::Behaviour: NetworkBehaviour,
// {
// type ConnectionHandler =
// ConnectionHandlerSelect<THandler<mdns::tokio::Behaviour>, THandler<gossipsub::Behaviour>>;
// type ToSwarm = DiscoveryBehaviourEvent;
// #[allow(clippy::needless_question_mark)]
// fn handle_pending_inbound_connection(
// &mut self,
// connection_id: ConnectionId,
// local_addr: &Multiaddr,
// remote_addr: &Multiaddr,
// ) -> Result<(), ConnectionDenied> {
// NetworkBehaviour::handle_pending_inbound_connection(
// &mut self.mdns,
// connection_id,
// local_addr,
// remote_addr,
// )?;
// NetworkBehaviour::handle_pending_inbound_connection(
// &mut self.gossipsub,
// connection_id,
// local_addr,
// remote_addr,
// )?;
// Ok(())
// }
// #[allow(clippy::needless_question_mark)]
// fn handle_established_inbound_connection(
// &mut self,
// connection_id: ConnectionId,
// peer: PeerId,
// local_addr: &Multiaddr,
// remote_addr: &Multiaddr,
// ) -> Result<THandler<Self>, ConnectionDenied> {
// Ok(ConnectionHandler::select(
// self.mdns.handle_established_inbound_connection(
// connection_id,
// peer,
// local_addr,
// remote_addr,
// )?,
// self.gossipsub.handle_established_inbound_connection(
// connection_id,
// peer,
// local_addr,
// remote_addr,
// )?,
// ))
// }
// #[allow(clippy::needless_question_mark)]
// fn handle_pending_outbound_connection(
// &mut self,
// connection_id: ConnectionId,
// maybe_peer: Option<PeerId>,
// addresses: &[Multiaddr],
// effective_role: Endpoint,
// ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
// let mut combined_addresses = Vec::new();
// combined_addresses.extend(NetworkBehaviour::handle_pending_outbound_connection(
// &mut self.mdns,
// connection_id,
// maybe_peer,
// addresses,
// effective_role,
// )?);
// combined_addresses.extend(NetworkBehaviour::handle_pending_outbound_connection(
// &mut self.gossipsub,
// connection_id,
// maybe_peer,
// addresses,
// effective_role,
// )?);
// Ok(combined_addresses)
// }
// #[allow(clippy::needless_question_mark)]
// fn handle_established_outbound_connection(
// &mut self,
// connection_id: ConnectionId,
// peer: PeerId,
// addr: &Multiaddr,
// role_override: Endpoint,
// port_use: PortUse,
// ) -> Result<THandler<Self>, ConnectionDenied> {
// Ok(ConnectionHandler::select(
// self.mdns.handle_established_outbound_connection(
// connection_id,
// peer,
// addr,
// role_override,
// port_use,
// )?,
// self.gossipsub.handle_established_outbound_connection(
// connection_id,
// peer,
// addr,
// role_override,
// port_use,
// )?,
// ))
// }
// fn on_swarm_event(&mut self, event: FromSwarm) {
// self.mdns.on_swarm_event(event);
// self.gossipsub.on_swarm_event(event);
// }
// fn on_connection_handler_event(
// &mut self,
// peer_id: PeerId,
// connection_id: ConnectionId,
// event: THandlerOutEvent<Self>,
// ) {
// match event {
// Either::Left(ev) => NetworkBehaviour::on_connection_handler_event(
// &mut self.mdns,
// peer_id,
// connection_id,
// ev,
// ),
// Either::Right(ev) => NetworkBehaviour::on_connection_handler_event(
// &mut self.gossipsub,
// peer_id,
// connection_id,
// ev,
// ),
// }
// }
// fn poll(
// &mut self,
// cx: &mut std::task::Context,
// ) -> std::task::Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// match NetworkBehaviour::poll(&mut self.mdns, cx) {
// std::task::Poll::Ready(e) => {
// return std::task::Poll::Ready(
// e.map_out(DiscoveryBehaviourEvent::Mdns)
// .map_in(|event| Either::Left(event)),
// );
// }
// std::task::Poll::Pending => {}
// }
// match NetworkBehaviour::poll(&mut self.gossipsub, cx) {
// std::task::Poll::Ready(e) => {
// return std::task::Poll::Ready(
// e.map_out(DiscoveryBehaviourEvent::Gossipsub)
// .map_in(|event| Either::Right(event)),
// );
// }
// std::task::Poll::Pending => {}
// }
// std::task::Poll::Pending
// }
// }
fn mdns_behaviour(keypair: &identity::Keypair) -> AnyResult<mdns::tokio::Behaviour> {
use mdns::{tokio, Config};
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {

View File

@@ -110,6 +110,16 @@ class Multiaddr:
r"""
TODO: documentation
"""
@staticmethod
def from_bytes(bytes:bytes) -> Multiaddr:
r"""
Build a ``Multiaddr`` from its byte-encoded form.

Raises an exception if the bytes cannot be parsed as a multiaddr.
"""
@staticmethod
def from_string(string:builtins.str) -> Multiaddr:
r"""
Build a ``Multiaddr`` from its textual form (e.g. ``/ip4/127.0.0.1/tcp/1234``).

Raises an exception if the string cannot be parsed as a multiaddr.
"""
def len(self) -> builtins.int:
r"""
TODO: documentation
@@ -122,8 +132,10 @@ class Multiaddr:
r"""
TODO: documentation
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
def to_string(self) -> builtins.str:
r"""
Return the textual form of this multiaddr (e.g. ``/ip4/127.0.0.1/tcp/1234``).
"""
class PeerId:
r"""

View File

@@ -9,7 +9,7 @@ use crate::ext::ResultExt;
use crate::pylibp2p::connection::PyConnectionId;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use crate::pylibp2p::multiaddr::PyMultiaddr;
use crate::{alias, pyclass, MPSC_CHANNEL_SIZE};
use crate::{MPSC_CHANNEL_SIZE, alias, pyclass};
use discovery::behaviour::{DiscoveryBehaviour, DiscoveryBehaviourEvent};
use discovery::discovery_swarm;
use libp2p::core::ConnectedPoint;
@@ -17,9 +17,9 @@ use libp2p::futures::StreamExt;
use libp2p::multiaddr::multiaddr;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{ConnectionId, SwarmEvent, ToSwarm};
use libp2p::{gossipsub, mdns, Multiaddr, PeerId, Swarm};
use libp2p::{Multiaddr, PeerId, Swarm, gossipsub, mdns};
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{pymethods, Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python};
use pyo3::{Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::convert::identity;
use std::error::Error;
@@ -274,7 +274,10 @@ impl PyDiscoveryService {
#[allow(clippy::expect_used)]
fn add_connected_callback<'py>(
&self,
#[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))]
#[gen_stub(override_type(
type_repr="collections.abc.Callable[[ConnectionUpdate], None]",
imports=("collections.abc")
))]
callback: PyObject,
) -> PyResult<()> {
use pyo3_async_runtimes::tokio::get_runtime;
@@ -304,7 +307,10 @@ impl PyDiscoveryService {
#[allow(clippy::expect_used)]
fn add_disconnected_callback<'py>(
&self,
#[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))]
#[gen_stub(override_type(
type_repr="collections.abc.Callable[[ConnectionUpdate], None]",
imports=("collections.abc")
))]
callback: PyObject,
) -> PyResult<()> {
use pyo3_async_runtimes::tokio::get_runtime;

View File

@@ -1,8 +1,10 @@
use crate::ext::ResultExt;
use libp2p::Multiaddr;
use pyo3::prelude::{PyModule, PyModuleMethods};
use pyo3::prelude::{PyBytesMethods, PyModule, PyModuleMethods};
use pyo3::types::PyBytes;
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr;
/// TODO: documentation...
#[gen_stub_pyclass]
@@ -27,6 +29,19 @@ impl PyMultiaddr {
Self(Multiaddr::with_capacity(n))
}
/// Build a `Multiaddr` from its byte-encoded form.
///
/// Raises a Python exception (produced by `ResultExt::pyerr`) if the bytes
/// cannot be parsed as a multiaddr.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Copy out of the Python object: `Multiaddr::try_from` takes an owned `Vec<u8>`.
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Build a `Multiaddr` from its textual form (e.g. `/ip4/127.0.0.1/tcp/1234`).
///
/// Raises a Python exception (produced by `ResultExt::pyerr`) if the string
/// cannot be parsed as a multiaddr.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// TODO: documentation
fn len(&self) -> usize {
self.0.len()
@@ -43,12 +58,19 @@ impl PyMultiaddr {
PyBytes::new(py, &bytes)
}
/// Return the textual form of this multiaddr (e.g. `/ip4/127.0.0.1/tcp/1234`).
///
/// Exposed as an explicit method in addition to `__str__`, so Python callers
/// can obtain the string without relying on `str()`.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.0.to_string()
self.to_string()
}
}

View File

@@ -1,10 +1,5 @@
import logging
import multiprocessing
import multiprocessing.queues
import pickle
import asyncio
import time
from collections.abc import Awaitable
from typing import Callable
import pytest
from exo_pyo3_bindings import ConnectionUpdate, Keypair, DiscoveryService
@@ -49,43 +44,86 @@ async def test_discovery_callbacks() -> None:
ident = Keypair.generate_ed25519()
service = DiscoveryService(ident)
service.add_connected_callback(add_connected_callback)
service.add_disconnected_callback(disconnected_callback)
a = _add_connected_callback(service)
d = _add_disconnected_callback(service)
for i in range(0, 1):
# stream_get_a, stream_put = _make_iter()
# service.add_connected_callback(stream_put)
#
# stream_get_d, stream_put = _make_iter()
# service.add_disconnected_callback(stream_put)
# async for c in stream_get_a:
# await connected_callback(c)
for i in range(0, 10):
print(f"PYTHON: tick {i} of 10")
time.sleep(1)
await asyncio.sleep(1)
pass
print(service, a, d) # only done to prevent GC... TODO: come up with less hacky solution
def add_connected_callback(e: ConnectionUpdate) -> None:
def _add_connected_callback(d: DiscoveryService):
    """Hook ``connected_callback`` up to *d* and return the task that drives it.

    Updates delivered by the service are pushed through the ``_make_iter``
    bridge and handled one at a time on the event loop.
    """
    updates, enqueue = _make_iter()

    async def _pump():
        async for update in updates:
            await connected_callback(update)

    d.add_connected_callback(enqueue)
    pump_task = asyncio.create_task(_pump())
    return pump_task
def _add_disconnected_callback(d: DiscoveryService):
    """Hook ``disconnected_callback`` up to *d* and return the task that drives it.

    Updates delivered by the service are pushed through the ``_make_iter``
    bridge and handled one at a time on the event loop.
    """
    updates, enqueue = _make_iter()

    async def _pump():
        async for update in updates:
            await disconnected_callback(update)

    d.add_disconnected_callback(enqueue)
    pump_task = asyncio.create_task(_pump())
    return pump_task
async def connected_callback(e: ConnectionUpdate) -> None:
print(f"\n\nPYTHON: Connected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}")
print(
f"PYTHON: Connected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
def disconnected_callback(e: ConnectionUpdate) -> None:
async def disconnected_callback(e: ConnectionUpdate) -> None:
print(f"\n\nPYTHON: Disconnected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}")
print(
f"PYTHON: Disconnected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n")
# async def foobar(a: Callable[[str], Awaitable[str]]):
# abc = await a("")
# pass
def _foo_task() -> None:
print("PYTHON: This simply runs in asyncio context")
# def test_keypair_pickling() -> None:
# def subprocess_task(kp: Keypair, q: multiprocessing.queues.Queue[Keypair]):
# logging.info("a")
# assert q.get() == kp
# logging.info("b")
def _make_iter():
    """Create a thread-safe bridge from a plain callback to an async iterator.

    Returns a ``(stream, put)`` pair: ``put(c)`` may be called from any thread
    and schedules ``c`` onto the event loop that was running when this function
    was called; ``stream`` is an endless async generator yielding the enqueued
    items in the order they were scheduled.

    Raises:
        RuntimeError: if called while no event loop is running. Binding to the
            running loop (instead of ``asyncio.get_event_loop()``) guarantees
            that ``put`` targets the loop that will consume ``stream``.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[ConnectionUpdate] = asyncio.Queue()

    def put(c: "ConnectionUpdate") -> None:
        # ``asyncio.Queue`` is not thread-safe, so hop onto the loop thread first.
        loop.call_soon_threadsafe(queue.put_nowait, c)

    async def get():
        while True:
            yield await queue.get()

    return get(), put
# async def inputstream_generator(channels=1, **kwargs):
# """Generator that yields blocks of input data as NumPy arrays."""
# q_in = asyncio.Queue()
# loop = asyncio.get_event_loop()
#
# def callback(indata, frame_count, time_info, status):
# loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status))
#
# kp = Keypair.generate_ed25519()
# q: multiprocessing.queues.Queue[Keypair] = multiprocessing.Queue()
#
# p = multiprocessing.Process(target=subprocess_task, args=(kp, q))
# p.start()
# q.put(kp)
# p.join()
# stream = sd.InputStream(callback=callback, channels=channels, **kwargs)
# with stream:
# while True:
# indata, status = await q_in.get()
# yield indata, status

View File

@@ -1,44 +0,0 @@
from __future__ import annotations
import logging
import os
from pathlib import Path
from exo_pyo3_bindings import Keypair
from filelock import FileLock
from shared.constants import EXO_NODE_ID_KEYPAIR
"""
This file is responsible for concurrent race-free persistent node-ID retrieval.
"""
def _lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
    """
    Return the :class:`Keypair` stored at ``path``, creating one if necessary.

    The whole read-or-create sequence runs under a cross-process file lock.
    If the file already holds data that decodes successfully, that keypair is
    returned; otherwise a new ed25519 keypair is generated, written to
    ``path`` in protobuf encoding, and returned.
    The node's :class:`PeerId` is obtained from the returned keypair.
    """
    with FileLock(_lock_path(path)):
        # Append mode creates the file when missing and starts positioned at EOF,
        # so a non-zero offset means there is existing content to try.
        with open(path, 'a+b') as existing:
            has_content = existing.tell() != 0
            if has_content:
                existing.seek(0)
                stored_bytes = existing.read()
                try:
                    return Keypair.from_protobuf_encoding(stored_bytes)
                except RuntimeError as e:
                    # Treat an undecodable file as corrupt and fall through to regeneration.
                    logging.warning(f"Encountered runtime error when trying to get keypair: {e}")

        # No usable keypair on disk: generate one and overwrite the file with it.
        with open(path, 'w+b') as fresh:
            keypair = Keypair.generate_ed25519()
            fresh.write(keypair.to_protobuf_encoding())
            return keypair

View File

@@ -14,7 +14,7 @@ from typing import Optional
from pytest import LogCaptureFixture
from shared.constants import EXO_NODE_ID_KEYPAIR
from shared.node_id import get_node_id_keypair
from shared.utils import get_node_id_keypair
NUM_CONCURRENT_PROCS = 10

View File

@@ -13,10 +13,10 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
source_node_id=node_a,
sink_node_id=node_b,
source_multiaddr="/ip4/127.0.0.1/tcp/10000",
sink_multiaddr="/ip4/127.0.0.1/tcp/10001",
local_node_id=node_a,
send_back_node_id=node_b,
local_multiaddr="/ip4/127.0.0.1/tcp/10000",
send_back_multiaddr="/ip4/127.0.0.1/tcp/10001",
)
state = State()
@@ -27,4 +27,4 @@ def test_state_serialization_roundtrip() -> None:
restored_state = State.model_validate_json(json_repr)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr
assert restored_state.model_dump_json() == json_repr

View File

@@ -63,18 +63,17 @@ class Topology(TopologyProto):
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def add_connection(
self,
connection: Connection,
self,
connection: Connection,
) -> None:
if connection.source_node_id not in self._node_id_to_rx_id_map:
self.add_node(Node(node_id=connection.source_node_id))
if connection.sink_node_id not in self._node_id_to_rx_id_map:
self.add_node(Node(node_id=connection.sink_node_id))
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(Node(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(Node(node_id=connection.send_back_node_id))
src_id = self._node_id_to_rx_id_map[connection.source_node_id]
sink_id = self._node_id_to_rx_id_map[connection.sink_node_id]
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
@@ -89,15 +88,15 @@ class Topology(TopologyProto):
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
def update_node_profile(self, node_id: NodeId, node_profile: NodePerformanceProfile) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
@@ -112,7 +111,7 @@ class Topology(TopologyProto):
def remove_connection(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
if self._is_bridge(connection):
orphan_node_ids = self._get_orphan_node_ids(connection.source_node_id, connection)
orphan_node_ids = self._get_orphan_node_ids(connection.local_node_id, connection)
for orphan_node_id in orphan_node_ids:
orphan_node_rx_id = self._node_id_to_rx_id_map[orphan_node_id]
self._graph.remove_node(orphan_node_rx_id)
@@ -122,16 +121,16 @@ class Topology(TopologyProto):
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
del self._rx_id_to_node_id_map[rx_idx]
def get_cycles(self) -> list[list[Node]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[Node]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def _is_bridge(self, connection: Connection) -> bool:
edge_idx = self._edge_id_to_rx_id_map[connection]
graph_copy = self._graph.copy().to_undirected()
@@ -141,17 +140,17 @@ class Topology(TopologyProto):
components_after = rx.number_connected_components(graph_copy)
return components_after > components_before
def _get_orphan_node_ids(self, master_node_id: NodeId, connection: Connection) -> list[NodeId]:
edge_idx = self._edge_id_to_rx_id_map[connection]
graph_copy = self._graph.copy().to_undirected()
graph_copy.remove_edge_from_index(edge_idx)
components = rx.connected_components(graph_copy)
orphan_node_rx_ids: set[int] = set()
orphan_node_rx_ids: set[int] = set()
master_node_rx_id = self._node_id_to_rx_id_map[master_node_id]
for component in components:
if master_node_rx_id not in component:
orphan_node_rx_ids.update(component)
return [self._rx_id_to_node_id_map[rx_id] for rx_id in orphan_node_rx_ids]

View File

@@ -11,15 +11,17 @@ class ID(str):
@classmethod
def __get_pydantic_core_schema__(
cls,
_source: type[Any],
handler: GetCoreSchemaHandler
cls,
_source: type[Any],
handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Reuse the already-defined schema for `str`
return handler.generate_schema(str)
class NodeId(ID):
pass
class CommandId(ID):
pass

View File

@@ -177,6 +177,10 @@ class TopologyEdgeCreated(_BaseEvent[_EventType.TopologyEdgeCreated]):
class TopologyEdgeReplacedAtomically(_BaseEvent[_EventType.TopologyEdgeReplacedAtomically]):
"""
TODO: delete this????
"""
event_type: Literal[_EventType.TopologyEdgeReplacedAtomically] = _EventType.TopologyEdgeReplacedAtomically
edge: Connection
edge_profile: ConnectionProfile
@@ -186,6 +190,7 @@ class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]):
event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted
edge: Connection
_Event = Union[
TaskCreated,
TaskStateUpdated,
@@ -263,8 +268,6 @@ def _check_event_type_consistency():
_check_event_type_consistency()
Event = Annotated[_Event, Field(discriminator="event_type")]
"""Type of events, a discriminated union."""
@@ -276,4 +279,4 @@ Event = Annotated[_Event, Field(discriminator="event_type")]
#
# class TimerFired(_BaseEvent[_EventType.TimerFired]):
# event_type: Literal[_EventType.TimerFired] = _EventType.TimerFired
# timer_id: TimerId
# timer_id: TimerId

View File

@@ -7,31 +7,33 @@ from shared.types.profiling import ConnectionProfile, NodePerformanceProfile
class Connection(BaseModel):
source_node_id: NodeId
sink_node_id: NodeId
source_multiaddr: str
sink_multiaddr: str
local_node_id: NodeId
send_back_node_id: NodeId
local_multiaddr: str
send_back_multiaddr: str
connection_profile: ConnectionProfile | None = None
# required for Connection to be used as a key
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
def __hash__(self) -> int:
return hash(
(
self.source_node_id,
self.sink_node_id,
self.source_multiaddr,
self.sink_multiaddr,
)
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.local_multiaddr,
self.send_back_multiaddr,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.source_node_id == other.source_node_id
and self.sink_node_id == other.sink_node_id
and self.source_multiaddr == other.source_multiaddr
and self.sink_multiaddr == other.sink_multiaddr
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.local_multiaddr == other.local_multiaddr
and self.send_back_multiaddr == other.send_back_multiaddr
)
@@ -44,8 +46,8 @@ class TopologyProto(Protocol):
def add_node(self, node: Node) -> None: ...
def add_connection(
self,
connection: Connection,
self,
connection: Connection,
) -> None: ...
def list_nodes(self) -> Iterable[Node]: ...

View File

@@ -1,7 +1,64 @@
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Type
from exo_pyo3_bindings import Keypair
from filelock import FileLock
from shared.constants import EXO_NODE_ID_KEYPAIR
def ensure_type[T](obj: Any, expected_type: Type[T]) -> T: # type: ignore
if not isinstance(obj, expected_type):
raise TypeError(f"Expected {expected_type}, got {type(obj)}") # type: ignore
return obj
# def make_async_iter[T]():
# """
# Creates a pair `<async-iter>, <put-to-iter>` of an asynchronous iterator
# and a synchronous function to put items into that iterator.
# """
#
# loop = asyncio.get_event_loop()
# queue: asyncio.Queue[T] = asyncio.Queue()
#
# def put(c: ConnectionUpdate) -> None:
# loop.call_soon_threadsafe(queue.put_nowait, (c,))
#
# async def get():
# while True:
# yield await queue.get()
#
# return get(), put
def get_node_id_keypair(path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR) -> Keypair:
    """
    Obtains the :class:`Keypair` associated with this node-ID, creating and
    persisting a new one if none can be loaded from ``path``.
    The node's :class:`PeerId` is obtained from the returned keypair.

    :param path: location of the protobuf-encoded keypair file; a sibling
        ``<path>.lock`` file is used as the cross-process lock.
    :return: the stored keypair if the file decodes successfully, otherwise a
        freshly generated ed25519 keypair that has been written to ``path``.
    """

    # os.fsdecode (not str) so that bytes paths become a real filename instead
    # of their "b'...'" repr, which would place the lock at a bogus location.
    lock_file = Path(os.fsdecode(path) + ".lock")

    # operate with cross-process lock to avoid race conditions
    with FileLock(lock_file):
        with open(path, 'a+b') as f:  # opens in append-mode => starts at EOF
            # if non-zero EOF, then file exists => use to get node-ID
            if f.tell() != 0:
                f.seek(0)  # go to start & read protobuf-encoded bytes
                protobuf_encoded = f.read()

                try:  # if decoded successfully, return
                    return Keypair.from_protobuf_encoding(protobuf_encoded)
                except RuntimeError as e:  # on runtime error, assume corrupt file
                    logging.warning(f"Encountered runtime error when trying to get keypair: {e}")

        # if no valid credentials, create new ones and persist
        with open(path, 'w+b') as f:
            keypair = Keypair.generate_ed25519()
            f.write(keypair.to_protobuf_encoding())
            return keypair

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, ConfigDict
from shared.apply import apply
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.node_id import get_node_id_keypair
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
@@ -54,6 +53,7 @@ from shared.types.worker.runners import (
RunningRunnerStatus,
)
from shared.types.worker.shards import ShardMetadata
from shared.utils import get_node_id_keypair
from worker.download.download_utils import build_model_path
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.utils.profile import start_polling_node_metrics
@@ -226,7 +226,6 @@ class Worker:
assigned_runner.status = ReadyRunnerStatus()
yield assigned_runner.status_update_event()
async def _execute_task_op(
self, op: ExecuteTaskOp
) -> AsyncGenerator[Event, None]:
@@ -308,7 +307,6 @@ class Worker:
# Ensure the task is cleaned up
await task
## Operation Planner
async def _execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
@@ -474,7 +472,7 @@ class Worker:
running_runner_count = 0
for other_runner_id, other_runner_status in state.runners.items():
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
isinstance(other_runner_status, RunningRunnerStatus):
isinstance(other_runner_status, RunningRunnerStatus):
running_runner_count += 1
if running_runner_count == runner.shard_metadata.world_size - 1: