From 8d2536d92635d00d15b24284dd5d70304d09b036 Mon Sep 17 00:00:00 2001 From: Andrei Cravtov Date: Wed, 23 Jul 2025 13:11:29 +0100 Subject: [PATCH] Implemented basic discovery library in Rust + python bindings Co-authored-by: Gelu Vrabie Co-authored-by: Seth Howes Co-authored-by: Matt Beton --- .gitignore | 1 - .idea/.gitignore | 8 + .idea/LanguageServersSettings.xml | 16 + .idea/exo-v2.iml | 20 + .idea/externalDependencies.xml | 6 + .idea/inspectionProfiles/Project_Default.xml | 15 + .idea/misc.xml | 10 + .idea/modules.xml | 8 + .idea/pyright-overrides.xml | 17 + .idea/pyright.xml | 11 + .idea/vcs.xml | 6 + flake.lock | 29 +- flake.nix | 105 ++++-- justfile | 9 +- master/api.py | 3 +- master/main.py | 2 +- master/placement.py | 2 +- networking/README.md | 0 networking/topology/.gitignore | 1 - networking/topology/Cargo.lock | 171 --------- networking/topology/Cargo.toml | 14 - networking/topology/pyproject.toml | 21 -- networking/topology/src/lib.rs | 15 - .../topology/src/networking/__init__.py | 5 - networking/topology/src/networking/_core.pyi | 0 pyproject.toml | 16 +- rust/.gitignore | 11 + rust/Cargo.toml | 166 ++++++++ rust/clippy.toml | 2 + rust/discovery/Cargo.toml | 38 ++ rust/discovery/src/behaviour.rs | 61 +++ rust/discovery/src/lib.rs | 137 +++++++ rust/discovery/src/transport.rs | 80 ++++ rust/discovery/tests/dummy.rs | 8 + rust/exo_pyo3_bindings/Cargo.toml | 76 ++++ rust/exo_pyo3_bindings/README.md | 1 + rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi | 148 ++++++++ rust/exo_pyo3_bindings/pyproject.toml | 35 ++ rust/exo_pyo3_bindings/src/bin/stub_gen.rs | 32 ++ rust/exo_pyo3_bindings/src/discovery.rs | 353 ++++++++++++++++++ rust/exo_pyo3_bindings/src/lib.rs | 101 +++++ .../src/pylibp2p/connection.rs | 36 ++ rust/exo_pyo3_bindings/src/pylibp2p/ident.rs | 130 +++++++ rust/exo_pyo3_bindings/src/pylibp2p/mod.rs | 3 + .../src/pylibp2p/multiaddr.rs | 59 +++ rust/exo_pyo3_bindings/tests/dummy.rs | 54 +++ rust/exo_pyo3_bindings/tests/test_python.py | 72 ++++ rust/master_election/Cargo.toml | 41 ++ rust/master_election/src/cel/centrality.rs | 36 ++ rust/master_election/src/cel/messaging.rs | 57 +++ rust/master_election/src/cel/mod.rs | 333 +++++++++++++++++ rust/master_election/src/communicator.rs | 35 ++ rust/master_election/src/lib.rs | 44 +++ rust/master_election/src/participant.rs | 203 ++++++++++ rust/master_election/tests/dummy.rs | 8 + rust/rust-toolchain.toml | 2 + rust/util/Cargo.toml | 26 ++ rust/util/fn_pipe/Cargo.toml | 16 + rust/util/fn_pipe/proc/Cargo.toml | 20 + rust/util/fn_pipe/proc/src/lib.rs | 201 ++++++++++ rust/util/fn_pipe/src/lib.rs | 35 ++ rust/util/src/lib.rs | 53 +++ rust/util/src/nonempty.rs | 145 +++++++ shared/db/sqlite/connector.py | 3 +- shared/db/sqlite/types.py | 2 +- shared/event_loops/main.py | 2 +- shared/pyproject.toml | 3 + shared/tests/test_sqlite_connector.py | 8 +- shared/types/events/__init__.py | 99 +++++ shared/types/events/{common.py => _common.py} | 55 +-- shared/types/events/_events.py | 132 +++++++ shared/types/events/categories.py | 9 +- shared/types/events/commands.py | 3 +- shared/types/events/components.py | 2 +- shared/types/events/events.py | 137 ------- shared/types/events/registry.py | 107 ------ shared/types/events/sanity_checking.py | 75 ---- shared/types/worker/ops.py | 2 +- throwaway_tests/segfault_multiprocess.py | 31 ++ uv.lock | 56 ++- worker/main.py | 3 +- worker/pyproject.toml | 2 + worker/tests/test_worker_handlers.py | 3 +- 83 files changed, 3448 insertions(+), 655 deletions(-) create mode 100644 .idea/.gitignore create 
mode 100644 .idea/LanguageServersSettings.xml create mode 100644 .idea/exo-v2.iml create mode 100644 .idea/externalDependencies.xml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/pyright-overrides.xml create mode 100644 .idea/pyright.xml create mode 100644 .idea/vcs.xml delete mode 100644 networking/README.md delete mode 100644 networking/topology/.gitignore delete mode 100644 networking/topology/Cargo.lock delete mode 100644 networking/topology/Cargo.toml delete mode 100644 networking/topology/pyproject.toml delete mode 100644 networking/topology/src/lib.rs delete mode 100644 networking/topology/src/networking/__init__.py delete mode 100644 networking/topology/src/networking/_core.pyi create mode 100644 rust/.gitignore create mode 100644 rust/Cargo.toml create mode 100644 rust/clippy.toml create mode 100644 rust/discovery/Cargo.toml create mode 100644 rust/discovery/src/behaviour.rs create mode 100644 rust/discovery/src/lib.rs create mode 100644 rust/discovery/src/transport.rs create mode 100644 rust/discovery/tests/dummy.rs create mode 100644 rust/exo_pyo3_bindings/Cargo.toml create mode 100644 rust/exo_pyo3_bindings/README.md create mode 100644 rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi create mode 100644 rust/exo_pyo3_bindings/pyproject.toml create mode 100644 rust/exo_pyo3_bindings/src/bin/stub_gen.rs create mode 100644 rust/exo_pyo3_bindings/src/discovery.rs create mode 100644 rust/exo_pyo3_bindings/src/lib.rs create mode 100644 rust/exo_pyo3_bindings/src/pylibp2p/connection.rs create mode 100644 rust/exo_pyo3_bindings/src/pylibp2p/ident.rs create mode 100644 rust/exo_pyo3_bindings/src/pylibp2p/mod.rs create mode 100644 rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs create mode 100644 rust/exo_pyo3_bindings/tests/dummy.rs create mode 100644 rust/exo_pyo3_bindings/tests/test_python.py create mode 100644 rust/master_election/Cargo.toml create mode 100644 rust/master_election/src/cel/centrality.rs create mode 100644 rust/master_election/src/cel/messaging.rs create mode 100644 rust/master_election/src/cel/mod.rs create mode 100644 rust/master_election/src/communicator.rs create mode 100644 rust/master_election/src/lib.rs create mode 100644 rust/master_election/src/participant.rs create mode 100644 rust/master_election/tests/dummy.rs create mode 100644 rust/rust-toolchain.toml create mode 100644 rust/util/Cargo.toml create mode 100644 rust/util/fn_pipe/Cargo.toml create mode 100644 rust/util/fn_pipe/proc/Cargo.toml create mode 100644 rust/util/fn_pipe/proc/src/lib.rs create mode 100644 rust/util/fn_pipe/src/lib.rs create mode 100644 rust/util/src/lib.rs create mode 100644 rust/util/src/nonempty.rs create mode 100644 shared/types/events/__init__.py rename shared/types/events/{common.py => _common.py} (67%) create mode 100644 shared/types/events/_events.py delete mode 100644 shared/types/events/events.py delete mode 100644 shared/types/events/registry.py delete mode 100644 shared/types/events/sanity_checking.py create mode 100644 throwaway_tests/segfault_multiprocess.py diff --git a/.gitignore b/.gitignore index e9a1c1ff..4cf7c64f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ */__pycache__ __pycache__ -networking/target/* *.so hosts_*.json \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# 
Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/LanguageServersSettings.xml b/.idea/LanguageServersSettings.xml new file mode 100644 index 00000000..7d92ce2f --- /dev/null +++ b/.idea/LanguageServersSettings.xml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/exo-v2.iml b/.idea/exo-v2.iml new file mode 100644 index 00000000..01e49642 --- /dev/null +++ b/.idea/exo-v2.iml @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/externalDependencies.xml b/.idea/externalDependencies.xml new file mode 100644 index 00000000..c16deb13 --- /dev/null +++ b/.idea/externalDependencies.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 00000000..84212658 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,15 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 00000000..4c4cf56c --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..0ccec085 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/pyright-overrides.xml b/.idea/pyright-overrides.xml new file mode 100644 index 00000000..6fa46f1d --- /dev/null +++ b/.idea/pyright-overrides.xml @@ -0,0 +1,17 @@ + + + + + + \ No newline at end of file diff --git a/.idea/pyright.xml b/.idea/pyright.xml new file mode 100644 index 00000000..f3d73271 --- /dev/null +++ b/.idea/pyright.xml @@ -0,0 +1,11 @@ + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..94a25f7f --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/flake.lock b/flake.lock index b2380393..e4210f4f 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1749794982, - "narHash": "sha256-Kh9K4taXbVuaLC0IL+9HcfvxsSUx8dPB5s5weJcc9pc=", + "lastModified": 1752950548, + "narHash": "sha256-NS6BLD0lxOrnCiEOcvQCDVPXafX1/ek1dfJHX1nUIzc=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ee930f9755f58096ac6e8ca94a1887e0534e2d81", + "rev": "c87b95e25065c028d31a94f06a62927d18763fdf", "type": "github" }, "original": { @@ -37,7 +37,28 @@ "root": { "inputs": { "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs" + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1753156081, + "narHash": "sha256-N+8LM+zvS6cP+VG2vxgEEDCyX1T9EUq9wXTSvGwX9TM=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "8610c0f3801fc8dec7eb4b79c95fb39d16f38a80", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" } }, "systems": { diff --git a/flake.nix b/flake.nix index 44f676ac..ae20e4e2 100644 --- a/flake.nix +++ b/flake.nix @@ -1,19 +1,28 @@ { - description = "Exo development flake"; + description = "The development environment for Exo"; inputs = { nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; - flake-utils.url = "github:numtide/flake-utils"; + flake-utils = { + url = "github:numtide/flake-utils"; 
+ inputs.nixpkgs.follows = "nixpkgs"; + }; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; }; - outputs = { self, nixpkgs, flake-utils }: + outputs = { self, nixpkgs, rust-overlay, flake-utils }: flake-utils.lib.eachDefaultSystem (system: let - pkgs = import nixpkgs { inherit system; }; + overlays = [ (import rust-overlay) ]; + pkgs = (import nixpkgs) { + inherit system overlays; + }; # Go 1.23 compiler – align with go.mod go = pkgs.go_1_23; - # Build the networking/forwarder Go utility. forwarder = pkgs.buildGoModule { pname = "exo-forwarder"; @@ -25,40 +34,64 @@ # Only the main package at the repository root needs building. subPackages = [ "." ]; }; + + buildInputs = with pkgs; [ + ]; + nativeBuildInputs = with pkgs; [ + # This sets up the rust suite, automatically selecting the latest nightly version + (rust-bin.selectLatestNightlyWith + (toolchain: toolchain.default.override { + extensions = [ "rust-src" "clippy" ]; + })) + ]; in - { - packages = { - inherit forwarder; - default = forwarder; - }; + { + packages = { + inherit forwarder; + default = forwarder; + }; - apps.forwarder = { - type = "app"; - program = "${forwarder}/bin/forwarder"; - }; - apps.python-lsp = { - type = "app"; - program = "${pkgs.basedpyright}/bin/basedpyright-langserver"; - }; - apps.default = self.apps.${system}.forwarder; + apps = { + forwarder = { + type = "app"; + program = "${forwarder}/bin/forwarder"; + }; + python-lsp = { + type = "app"; + program = "${pkgs.basedpyright}/bin/basedpyright-langserver"; + }; + default = self.apps.${system}.forwarder; + }; - devShells.default = pkgs.mkShell { - packages = [ - pkgs.python313 - pkgs.uv - pkgs.just - pkgs.protobuf - pkgs.rustc - pkgs.cargo - pkgs.basedpyright - pkgs.ruff - go - ]; + devShells.default = pkgs.mkShell { + packages = [ + pkgs.python313 + pkgs.uv + pkgs.just + pkgs.protobuf + pkgs.basedpyright + pkgs.ruff + go + ]; - shellHook = '' - export GOPATH=$(mktemp -d) - ''; - }; - } + # TODO: change this into exported env via nix directly??? 
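+        # For now, a throwaway GOPATH keeps Go's module cache out of the repo checkout.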
+ shellHook = '' + export GOPATH=$(mktemp -d) + ''; + + nativeBuildInputs = with pkgs; [ + cargo-expand + nixpkgs-fmt + cmake + ] ++ buildInputs ++ nativeBuildInputs; + + # fixes libstdc++.so issues and libgl.so issues +# LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH"; + LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib"; + + # exports basedpyright path so tools can discover it + BASEDPYRIGHT_BIN_PATH = "${pkgs.basedpyright}/bin/"; + }; + } ); } \ No newline at end of file diff --git a/justfile b/justfile index 6cb6fc86..5865b22e 100644 --- a/justfile +++ b/justfile @@ -17,13 +17,16 @@ lint-check: uv run ruff check master worker shared engines/* test: - uv run pytest master worker shared engines/* + uv run pytest master worker shared engines/* rust/exo_pyo3_bindings/tests check: - basedpyright --project pyproject.toml + uv run basedpyright --project pyproject.toml sync: - uv sync --all-packages --reinstall + uv sync --all-packages + +sync-clean: + uv sync --all-packages --force-reinstall protobufs: just regenerate-protobufs diff --git a/master/api.py b/master/api.py index 28c78e48..f07e81f5 100644 --- a/master/api.py +++ b/master/api.py @@ -10,10 +10,9 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from shared.db.sqlite.connector import AsyncSQLiteEventStorage +from shared.types.events import ChunkGenerated, Event from shared.types.events.chunks import TokenChunk from shared.types.events.components import EventFromEventLog -from shared.types.events.events import ChunkGenerated -from shared.types.events.registry import Event from shared.types.request import APIRequest, RequestId from shared.types.tasks import ChatCompletionTaskParams diff --git a/master/main.py b/master/main.py index cb59ec45..3e99f808 100644 --- a/master/main.py +++ b/master/main.py @@ -8,8 +8,8 @@ from shared.db.sqlite.config import EventLogConfig from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogManager from shared.types.common import NodeId +from shared.types.events import ChunkGenerated from shared.types.events.chunks import TokenChunk -from shared.types.events.events import ChunkGenerated from shared.types.request import APIRequest, RequestId diff --git a/master/placement.py b/master/placement.py index be0d8f41..b9eb7d70 100644 --- a/master/placement.py +++ b/master/placement.py @@ -2,7 +2,7 @@ from queue import Queue from typing import Mapping, Sequence from shared.topology import Topology -from shared.types.events.registry import Event +from shared.types.events import Event from shared.types.state import CachePolicy from shared.types.tasks import Task from shared.types.worker.instances import InstanceId, InstanceParams diff --git a/networking/README.md b/networking/README.md deleted file mode 100644 index e69de29b..00000000 diff --git a/networking/topology/.gitignore b/networking/topology/.gitignore deleted file mode 100644 index 9f970225..00000000 --- a/networking/topology/.gitignore +++ /dev/null @@ -1 +0,0 @@ -target/ \ No newline at end of file diff --git a/networking/topology/Cargo.lock b/networking/topology/Cargo.lock deleted file mode 100644 index 328ad73a..00000000 --- a/networking/topology/Cargo.lock +++ /dev/null @@ -1,171 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 4 - -[[package]] -name = "autocfg" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" - -[[package]] -name = "cfg-if" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = "indoc" -version = "2.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" - -[[package]] -name = "libc" -version = "0.2.174" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" - -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "networking" -version = "0.1.0" -dependencies = [ - "pyo3", -] - -[[package]] -name = "once_cell" -version = "1.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" - -[[package]] -name = "portable-atomic" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" - -[[package]] -name = "proc-macro2" -version = "1.0.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" -dependencies = [ - "heck", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.40" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "syn" -version = "2.0.104" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" - -[[package]] -name = "unicode-ident" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" - -[[package]] -name = "unindent" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" diff --git a/networking/topology/Cargo.toml b/networking/topology/Cargo.toml deleted file mode 100644 index 6e458e40..00000000 --- a/networking/topology/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "networking" -version = "0.1.0" -edition = "2021" - -[lib] -name = "_core" -# "cdylib" is necessary to produce a shared library for Python to import from. -crate-type = ["cdylib"] - -[dependencies] -# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) -# "abi3-py39" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.9 -pyo3 = { version = "0.22.4", features = ["extension-module", "abi3-py39"] } diff --git a/networking/topology/pyproject.toml b/networking/topology/pyproject.toml deleted file mode 100644 index f2e82e89..00000000 --- a/networking/topology/pyproject.toml +++ /dev/null @@ -1,21 +0,0 @@ -[project] -name = "exo-networking" -version = "0.1.0" -description = "Add your description here" -authors = [ - { name = "Arbion Halili", email = "99731180+ToxicPine@users.noreply.github.com" } -] -requires-python = ">=3.13" -dependencies = [] - -[project.scripts] -networking = "networking:main" - -[tool.maturin] -module-name = "networking._core" -python-packages = ["networking"] -python-source = "src" - -[build-system] -requires = ["maturin>=1.0,<2.0"] -build-backend = "maturin" diff --git a/networking/topology/src/lib.rs b/networking/topology/src/lib.rs deleted file mode 100644 index 915d8a39..00000000 --- a/networking/topology/src/lib.rs +++ /dev/null @@ -1,15 +0,0 @@ -use pyo3::prelude::*; - -#[pyfunction] -fn hello_from_bin() -> String { - "Hello from networking!".to_string() -} - -/// A Python module implemented in Rust. The name of this function must match -/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to -/// import the module. 
-#[pymodule]
-fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(hello_from_bin, m)?)?;
-    Ok(())
-}
diff --git a/networking/topology/src/networking/__init__.py b/networking/topology/src/networking/__init__.py
deleted file mode 100644
index e357cd98..00000000
--- a/networking/topology/src/networking/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from networking._core import hello_from_bin
-
-
-def main() -> None:
-    print(hello_from_bin())
diff --git a/networking/topology/src/networking/_core.pyi b/networking/topology/src/networking/_core.pyi
deleted file mode 100644
index e69de29b..00000000
diff --git a/pyproject.toml b/pyproject.toml
index d4573c85..7d8aad79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,8 @@ dependencies = [
     "exo-master",
     "exo-worker",
     "types-aiofiles>=24.1.0.20250708",
+    "typeguard>=4.4.4",
+    "pydantic>=2.11.7"
 ]
 
 # dependencies only required for development
@@ -37,7 +39,7 @@ members = [
     "worker",
     "shared",
     "engines/*",
-    "networking/topology",
+    "rust/exo_pyo3_bindings",
 ]
 
 [tool.uv.sources]
@@ -45,7 +47,7 @@ exo-shared = { workspace = true }
 exo-master = { workspace = true }
 exo-worker = { workspace = true }
 exo-engine-mlx = { workspace = true }
-exo-networking = { workspace = true }
+exo-pyo3-bindings = { workspace = true }
 
 [build-system]
 requires = ["hatchling"]
@@ -66,9 +68,9 @@ only-include = ["pyproject.toml", "README.md"]
 # type-checker configuration
 ###
 
-[tool.basedpyright] 
+[tool.basedpyright]
 typeCheckingMode = "strict"
-failOnWarnings = true 
+failOnWarnings = true
 
 reportAny = "error"
 reportUnknownVariableType = "error"
@@ -80,11 +82,11 @@ reportUnnecessaryCast = "error"
 reportUnnecessaryTypeIgnoreComment = "error"
 
 include = ["master", "worker", "shared", "engines/*"]
-pythonVersion = "3.13" 
+pythonVersion = "3.13"
 pythonPlatform = "Darwin"
 
 stubPath = "shared/protobufs/types"
-ignore = [ 
+ignore = [
     "shared/protobufs/types/**/*",
 ]
@@ -111,4 +113,4 @@ extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
 
 [tool.pytest.ini_options]
 pythonpath = "."
-asyncio_mode = "auto"
+asyncio_mode = "auto"
\ No newline at end of file
diff --git a/rust/.gitignore b/rust/.gitignore
new file mode 100644
index 00000000..e9c71ef3
--- /dev/null
+++ b/rust/.gitignore
@@ -0,0 +1,11 @@
+/target
+compile
+.*
+./*.wacc
+*.s
+*.core
+.wacc
+*.png
+*.dot
+
+Cargo.lock
\ No newline at end of file
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
new file mode 100644
index 00000000..97c472da
--- /dev/null
+++ b/rust/Cargo.toml
@@ -0,0 +1,166 @@
+[workspace]
+resolver = "3"
+members = [
+    "discovery",
+    "exo_pyo3_bindings",
+    "master_election",
+    "util",
+    "util/fn_pipe",
+    "util/fn_pipe/proc",
+]
+
+[workspace.package]
+version = "0.0.1"
+edition = "2024"
+
+[profile.dev]
+opt-level = 1
+debug = true
+
+[profile.release]
+opt-level = 3
+
+# Common shared dependencies configured once at the workspace
+# level, to be re-used more easily across workspace member crates.
+#
+# Common configurations include versions, paths, features, etc.
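+#
+# A member crate then opts into one of these entries with, for example
+# (a minimal sketch, mirroring what `discovery/Cargo.toml` does later in this patch):
+#
+#     [dependencies]
+#     tokio = { workspace = true, features = ["full"] }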
+[workspace.dependencies]
+## Crate members as common dependencies
+discovery = { path = "discovery" }
+master_election = { path = "master_election" }
+util = { path = "util" }
+exo_pyo3_bindings = { path = "exo_pyo3_bindings" }
+fn_pipe = { path = "util/fn_pipe" }
+fn_pipe_proc = { path = "util/fn_pipe/proc" }
+
+
+# Proc-macro authoring tools
+syn = "2.0"
+quote = "1.0"
+proc-macro2 = "1.0"
+darling = "0.20"
+# Macro dependencies
+extend = "1.2"
+delegate = "0.13"
+impl-trait-for-tuples = "0.2"
+clap = "4.5"
+derive_more = { version = "2.0.1", features = ["display"] }
+# Utility dependencies
+itertools = "0.14"
+thiserror = "2"
+internment = "0.8"
+recursion = "0.5"
+regex = "1.11"
+once_cell = "1.21"
+thread_local = "1.1"
+bon = "3.4"
+generativity = "1.1"
+anyhow = "1.0"
+keccak-const = "0.2"
+# Functional generics/lenses frameworks
+frunk_core = "0.4"
+frunk = "0.4"
+frunk_utils = "0.2"
+frunk-enum-core = "0.3"
+# Async dependencies
+tokio = "1.46"
+futures = "0.3"
+futures-util = "0.3"
+# Data structures
+either = "1.15"
+ordered-float = "5.0"
+ahash = "0.8"
+# Networking
+libp2p = "0.56"
+libp2p-tcp = "0.44"
+# Interop
+pyo3 = "0.25"
+#pyo3-stub-gen = { git = "https://github.com/Jij-Inc/pyo3-stub-gen.git", rev = "d2626600e52452e71095c57e721514de748d419d" } # v0.11 not yet published to crates
+pyo3-stub-gen = { git = "https://github.com/cstruct/pyo3-stub-gen.git", rev = "2efddde7dcffc462868aa0e4bbc46877c657a0fe" } # This fork adds support for type overrides => not merged yet!!!
+pyo3-async-runtimes = "0.25"
+
+[workspace.lints.rust]
+static_mut_refs = "warn" # or use "deny" instead of "warn"
+incomplete_features = "allow"
+
+# Clippy's lint category level configurations;
+# every member crate needs to inherit these by adding
+#
+# ```toml
+# [lints]
+# workspace = true
+# ```
+#
+# to their `Cargo.toml` files
+[workspace.lints.clippy]
+# Clippy lint categories meant to be enabled all at once
+correctness = { level = "deny", priority = -1 }
+suspicious = { level = "warn", priority = -1 }
+style = { level = "warn", priority = -1 }
+complexity = { level = "warn", priority = -1 }
+perf = { level = "warn", priority = -1 }
+pedantic = { level = "warn", priority = -1 }
+nursery = { level = "warn", priority = -1 }
+cargo = { level = "warn", priority = -1 }
+
+# Individual Clippy lints from the `restriction` category
+arithmetic_side_effects = "warn"
+as_conversions = "warn"
+assertions_on_result_states = "warn"
+clone_on_ref_ptr = "warn"
+decimal_literal_representation = "warn"
+default_union_representation = "warn"
+deref_by_slicing = "warn"
+disallowed_script_idents = "deny"
+else_if_without_else = "warn"
+empty_enum_variants_with_brackets = "warn"
+empty_structs_with_brackets = "warn"
+error_impl_error = "warn"
+exit = "deny"
+expect_used = "warn"
+float_cmp_const = "warn"
+get_unwrap = "warn"
+if_then_some_else_none = "warn"
+impl_trait_in_params = "warn"
+indexing_slicing = "warn"
+infinite_loop = "warn"
+let_underscore_must_use = "warn"
+let_underscore_untyped = "warn"
+lossy_float_literal = "warn"
+mem_forget = "warn"
+missing_inline_in_public_items = "warn"
+multiple_inherent_impl = "warn"
+multiple_unsafe_ops_per_block = "warn"
+mutex_atomic = "warn"
+non_zero_suggestions = "warn"
+panic = "warn"
+partial_pub_fields = "warn"
+pattern_type_mismatch = "warn"
+pub_without_shorthand = "warn"
+rc_buffer = "warn"
+rc_mutex = "warn"
+redundant_type_annotations = "warn"
+renamed_function_params = "warn"
+rest_pat_in_fully_bound_structs = "warn"
+same_name_method = "warn"
+self_named_module_files = "deny"
+semicolon_inside_block = "warn"
+shadow_same = "warn"
+shadow_unrelated = "warn"
+str_to_string = "warn"
+string_add = "warn"
+string_lit_chars_any = "warn"
+string_to_string = "warn"
+tests_outside_test_module = "warn"
+todo = "warn"
+try_err = "warn"
+undocumented_unsafe_blocks = "warn"
+unnecessary_safety_comment = "warn"
+unnecessary_safety_doc = "warn"
+unneeded_field_pattern = "warn"
+unseparated_literal_suffix = "warn"
+unused_result_ok = "warn"
+unused_trait_names = "warn"
+unwrap_used = "warn"
+verbose_file_reads = "warn"
+static_mut_refs = "warn"
\ No newline at end of file
diff --git a/rust/clippy.toml b/rust/clippy.toml
new file mode 100644
index 00000000..6d5a6187
--- /dev/null
+++ b/rust/clippy.toml
@@ -0,0 +1,2 @@
+# We can manually exclude false-positive duplicate-crate lint errors here (for packages that appear twice in the dependency tree):
+#allowed-duplicate-crates = ["hashbrown"]
\ No newline at end of file
diff --git a/rust/discovery/Cargo.toml b/rust/discovery/Cargo.toml
new file mode 100644
index 00000000..6ca9ef17
--- /dev/null
+++ b/rust/discovery/Cargo.toml
@@ -0,0 +1,38 @@
+[package]
+name = "discovery"
+version = { workspace = true }
+edition = { workspace = true }
+publish = false
+
+[lib]
+doctest = false
+name = "discovery"
+path = "src/lib.rs"
+
+[lints]
+workspace = true
+
+[dependencies]
+# macro dependencies
+extend = { workspace = true }
+delegate = { workspace = true }
+impl-trait-for-tuples = { workspace = true }
+derive_more = { workspace = true }
+
+# Async
+tokio = { workspace = true, features = ["full"] }
+futures = { workspace = true }
+
+# utility dependencies
+#util = { workspace = true }
+#fn_pipe = { workspace = true }
+thiserror = { workspace = true }
+#internment = { workspace = true }
+#recursion = { workspace = true }
+#generativity = { workspace = true }
+#itertools = { workspace = true }
+tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
+keccak-const = { workspace = true }
+
+# Networking
+libp2p = { workspace = true, features = ["full"] }
\ No newline at end of file
diff --git a/rust/discovery/src/behaviour.rs b/rust/discovery/src/behaviour.rs
new file mode 100644
index 00000000..52a7032e
--- /dev/null
+++ b/rust/discovery/src/behaviour.rs
@@ -0,0 +1,61 @@
+use crate::alias::AnyResult;
+use libp2p::swarm::NetworkBehaviour;
+use libp2p::{gossipsub, identity, mdns};
+use std::hash::{DefaultHasher, Hash, Hasher};
+use std::time::Duration;
+
+/// Custom network behaviour for the `discovery` network; it combines [`mdns::tokio::Behaviour`] for
+/// the actual mDNS discovery, and [`gossipsub::Behaviour`] for PubSub functionality.
+#[derive(NetworkBehaviour)]
+pub struct DiscoveryBehaviour {
+    pub mdns: mdns::tokio::Behaviour,
+    pub gossipsub: gossipsub::Behaviour,
+}
+
+fn mdns_behaviour(keypair: &identity::Keypair) -> AnyResult<mdns::tokio::Behaviour> {
+    use mdns::{tokio, Config};
+
+    // mDNS config => enable IPv6
+    let mdns_config = Config {
+        // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
+        ..Default::default()
+    };
+
+    let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
+    Ok(mdns_behaviour?)
+}
+
+fn gossipsub_behaviour(keypair: &identity::Keypair) -> AnyResult<gossipsub::Behaviour> {
+    use gossipsub::ConfigBuilder;
+
+    // To content-address messages, we can take the hash of each message and use it as an ID.
+    let message_id_fn = |message: &gossipsub::Message| {
+        let mut s = DefaultHasher::new();
+        message.data.hash(&mut s);
+        gossipsub::MessageId::from(s.finish().to_string())
+    };
+
+    let gossipsub_config = ConfigBuilder::default()
+        // .mesh_n_low(1
+        .mesh_n(1) // this is for debugging!!! change to 6
+        // .mesh_n_for_topic(1, topic.hash()) // this is for debugging!!! change to 6
+        // .mesh_n_high(1)
+        .heartbeat_interval(Duration::from_secs(10)) // This is set to aid debugging by not cluttering the log space
+        .validation_mode(gossipsub::ValidationMode::None) // This sets the kind of message validation. Skip signing for speed.
+        .message_id_fn(message_id_fn) // content-address messages. No two messages of the same content will be propagated.
+        .build()?; // Temporary hack because `build` does not return a proper `std::error::Error`.
+
+    // build a gossipsub network behaviour
+    let gossipsub_behavior = gossipsub::Behaviour::new(
+        gossipsub::MessageAuthenticity::Signed(keypair.clone()),
+        gossipsub_config,
+    )?;
+    Ok(gossipsub_behavior)
+}
+
+pub fn discovery_behaviour(keypair: &identity::Keypair) -> AnyResult<DiscoveryBehaviour> {
+    Ok(DiscoveryBehaviour {
+        gossipsub: gossipsub_behaviour(keypair)?,
+        mdns: mdns_behaviour(keypair)?,
+    })
+}
diff --git a/rust/discovery/src/lib.rs b/rust/discovery/src/lib.rs
new file mode 100644
index 00000000..17cb78ca
--- /dev/null
+++ b/rust/discovery/src/lib.rs
@@ -0,0 +1,137 @@
+//! TODO: crate documentation
+//!
+//! this is here as a placeholder documentation
+//!
+//!
+
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+// #![feature(stmt_expr_attributes)]
+// #![feature(unboxed_closures)]
+// #![feature(assert_matches)]
+// #![feature(async_fn_in_dyn_trait)]
+// #![feature(async_for_loop)]
+// #![feature(auto_traits)]
+// #![feature(negative_impls)]
+
+use crate::behaviour::{discovery_behaviour, DiscoveryBehaviour};
+use crate::transport::discovery_transport;
+use libp2p::{identity, Swarm, SwarmBuilder};
+
+pub mod behaviour;
+pub mod transport;
+
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {
+    use std::error::Error;
+
+    pub type AnyError = Box<dyn Error + Send + Sync>;
+    pub type AnyResult<T, E = AnyError> = Result<T, E>;
+}
+
+/// Namespace for crate-wide extension traits/methods
+pub(crate) mod ext {}
+
+pub(crate) mod private {
+    /// Sealed traits support
+    pub trait Sealed {}
+    impl<T> Sealed for T {}
+}
+
+/// Create and configure a swarm, and start listening to all ports/OS.
+#[inline]
+pub fn discovery_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm<DiscoveryBehaviour>> {
+    let mut swarm = SwarmBuilder::with_existing_identity(keypair)
+        .with_tokio()
+        .with_other_transport(discovery_transport)?
+        .with_behaviour(discovery_behaviour)?
+        .build();
+
+    // Listen on all interfaces and whatever port the OS assigns
+    // swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?; // TODO: make this
+    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
+
+    Ok(swarm)
+}
+
+// TODO: - ensure that all changes to connections mean a Disconnect/Reconnect event is fired, i.e. if it switched IPs slightly or something
+//       - ensure that all links are unique, i.e. each connection has some kind of uniquely identifiable hash/multiaddress/whatever => temporally unique???
+//       - need pnet config, so that forwarder & discovery don't interfere with each other
+//       - discovery network needs persistence, so swarm created from existing identity (passed as arg)
+//       - connect/disconnect events etc.
should be handled with callbacks
+//       - DON'T need gossipsub JUST yet, only mDNS for discovery => potentially use something else instead of gossipsub
+
+#[cfg(test)]
+mod tests {
+    use crate::alias::AnyResult;
+    use crate::behaviour::DiscoveryBehaviourEvent;
+    use crate::discovery_swarm;
+    use futures::stream::StreamExt as _;
+    use libp2p::{gossipsub, identity, mdns, swarm::SwarmEvent};
+    use std::hash::Hash;
+    use tokio::{io, io::AsyncBufReadExt as _, select};
+    use tracing_subscriber::filter::LevelFilter;
+    use tracing_subscriber::util::SubscriberInitExt as _;
+    use tracing_subscriber::EnvFilter;
+
+    #[tokio::test]
+    async fn chatroom_test() -> AnyResult<()> {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::DEBUG.into()))
+            .try_init();
+
+        // Configure swarm
+        let mut swarm = discovery_swarm(identity::Keypair::generate_ed25519())?;
+
+        // Create a Gossipsub topic & subscribe
+        let topic = gossipsub::IdentTopic::new("test-net");
+        swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
+
+        // Read full lines from stdin
+        let mut stdin = io::BufReader::new(io::stdin()).lines();
+        println!(
+            "Enter messages via STDIN and they will be sent to connected peers using Gossipsub"
+        );
+
+        // Kick it off
+        loop {
+            select! {
+                Ok(Some(line)) = stdin.next_line() => {
+                    if let Err(e) = swarm
+                        .behaviour_mut().gossipsub
+                        .publish(topic.clone(), line.as_bytes()) {
+                        println!("Publish error: {e:?}");
+                    }
+                }
+                event = swarm.select_next_some() => match event {
+                    SwarmEvent::Behaviour(DiscoveryBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
+                        for (peer_id, multiaddr) in list {
+                            println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
+                            swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
+                        }
+                    },
+                    SwarmEvent::Behaviour(DiscoveryBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
+                        for (peer_id, multiaddr) in list {
+                            println!("mDNS discovered peer has expired: {peer_id} on {multiaddr}");
+                            swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                        }
+                    },
+                    SwarmEvent::Behaviour(DiscoveryBehaviourEvent::Gossipsub(gossipsub::Event::Message {
+                        propagation_source: peer_id,
+                        message_id: id,
+                        message,
+                    })) => println!(
+                        "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
+                        String::from_utf8_lossy(&message.data),
+                    ),
+                    SwarmEvent::NewListenAddr { address, .. } => {
+                        println!("Local node is listening on {address}");
+                    }
+                    e => {
+                        println!("Other event {e:?}");
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/rust/discovery/src/transport.rs b/rust/discovery/src/transport.rs
new file mode 100644
index 00000000..ee7213d8
--- /dev/null
+++ b/rust/discovery/src/transport.rs
@@ -0,0 +1,80 @@
+use crate::alias::AnyResult;
+use futures::{AsyncRead, AsyncWrite};
+use keccak_const::Sha3_256;
+use libp2p::{
+    core::{muxing, transport::Boxed}, identity,
+    noise,
+    pnet, quic, yamux, PeerId, Transport as _,
+};
+use std::any::Any;
+
+/// Key used for discovery's private network. See [`pnet_upgrade`] for more.
+const PNET_PRESHARED_KEY: [u8; 32] = Sha3_256::new().update(b"exo_discovery_network").finalize();
+
+/// Make `discovery` run on a private network, so as not to clash with the `forwarder` network.
+/// This is implemented as an additional "upgrade" on top of existing [`libp2p::Transport`] layers.
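+///
+/// Both ends derive the same key from [`PNET_PRESHARED_KEY`] above, so the pre-shared-key
+/// handshake can only succeed between peers built with this crate.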
+fn pnet_upgrade<Socket>(
+    socket: Socket,
+    _ignored: impl Any,
+) -> impl Future<Output = Result<pnet::PnetStream<Socket>, pnet::PnetError>>
+where
+    Socket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
+{
+    pnet::PnetConfig::new(pnet::PreSharedKey::new(PNET_PRESHARED_KEY)).handshake(socket)
+}
+
+/// TCP/IP transport layer configuration.
+fn tcp_transport(
+    keypair: &identity::Keypair,
+) -> AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
+    use libp2p::{
+        core::upgrade::Version,
+        tcp::{tokio, Config},
+    };
+
+    // `TCP_NODELAY` enabled => avoid latency
+    let tcp_config = Config::default().nodelay(true);
+
+    // V1 + lazy flushing => 0-RTT negotiation
+    let upgrade_version = Version::V1Lazy;
+
+    // Noise is faster than TLS + we don't care much for security
+    let noise_config = noise::Config::new(keypair)?;
+    //let tls_config = tls::Config::new(keypair)?; // TODO: add this in if needed?? => look into how `.with_tcp` does it...
+
+    // Use default Yamux config for multiplexing
+    let yamux_config = yamux::Config::default();
+
+    // Create new Tokio-driven TCP/IP transport layer
+    let base_transport = tokio::Transport::new(tcp_config)
+        .and_then(pnet_upgrade)
+        .upgrade(upgrade_version)
+        .authenticate(noise_config)
+        .multiplex(yamux_config);
+
+    // Return boxed transport (to flatten complex type)
+    Ok(base_transport.boxed())
+}
+
+/// QUIC transport layer configuration.
+fn quic_transport(keypair: &identity::Keypair) -> Boxed<(PeerId, quic::Connection)> {
+    use libp2p::quic::{tokio, Config};
+
+    let quic_config = Config::new(keypair);
+    let base_transport = tokio::Transport::new(quic_config).boxed();
+    //.and_then(); // As of now, QUIC doesn't support PNets ;( TODO: figure out in future how to do this
+    unimplemented!("you cannot use this yet !!!");
+    base_transport
+}
+
+/// Overall composed transport-layer configuration for the `discovery` network.
+pub fn discovery_transport(
+    keypair: &identity::Keypair,
+) -> AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
+    // TODO: when QUIC is figured out with PNET, re-enable this
+    // Ok(tcp_transport(keypair)?
+    //     .or_transport(quic_transport(keypair))
+    //     .boxed())
+
+    tcp_transport(keypair)
+}
diff --git a/rust/discovery/tests/dummy.rs b/rust/discovery/tests/dummy.rs
new file mode 100644
index 00000000..d82c6eb1
--- /dev/null
+++ b/rust/discovery/tests/dummy.rs
@@ -0,0 +1,8 @@
+// maybe this will hold tests in the future...??
+ +#[cfg(test)] +mod tests { + #[test] + fn does_nothing() { + } +} \ No newline at end of file diff --git a/rust/exo_pyo3_bindings/Cargo.toml b/rust/exo_pyo3_bindings/Cargo.toml new file mode 100644 index 00000000..db37d027 --- /dev/null +++ b/rust/exo_pyo3_bindings/Cargo.toml @@ -0,0 +1,76 @@ +[package] +name = "exo_pyo3_bindings" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +doctest = false +path = "src/lib.rs" +name = "exo_pyo3_bindings" + +# "cdylib" needed to produce shared library for Python to import +# "rlib" needed for stub-gen to run +crate-type = ["cdylib", "rlib"] + +[[bin]] +path = "src/bin/stub_gen.rs" +name = "stub_gen" +doc = false + +[lints] +workspace = true + +[dependencies] +discovery = { workspace = true } + +# interop +pyo3 = { workspace = true, features = [ + "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11 + "nightly", # enables better-supported GIL integration + "experimental-async", # async support in #[pyfunction] & #[pymethods] + #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation + #"py-clone", # adding Clone-ing of `Py` without GIL (may cause panics - remove if panics happen) + "multiple-pymethods", # allows multiple #[pymethods] sections per class + + # integrations with other libraries + "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational", + "ordered-float", "rust_decimal", "smallvec", + # "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde", +] } +pyo3-stub-gen = { workspace = true } +pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] } + +# macro dependencies +extend = { workspace = true } +delegate = { workspace = true } +impl-trait-for-tuples = { workspace = true } +derive_more = { workspace = true } + +# async runtime +tokio = { workspace = true, features = ["full", "tracing"] } + +# utility dependencies +once_cell = "1.21.3" +thread_local = "1.1.9" +#util = { workspace = true } +#fn_pipe = { workspace = true } +thiserror = { workspace = true } +#internment = { workspace = true } +#recursion = { workspace = true } +#generativity = { workspace = true } +#itertools = { workspace = true } + + +# Tracing +#tracing = "0.1" +#tracing-subscriber = "0.3" +#console-subscriber = "0.1.5" +#tracing-log = "0.2.0" +env_logger = "0.11" +log = "0.4" +pyo3-log = "0.12" + + +# Networking +libp2p = { workspace = true, features = ["full"] } diff --git a/rust/exo_pyo3_bindings/README.md b/rust/exo_pyo3_bindings/README.md new file mode 100644 index 00000000..e739dd89 --- /dev/null +++ b/rust/exo_pyo3_bindings/README.md @@ -0,0 +1 @@ +TODO: do something here.... diff --git a/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi b/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi new file mode 100644 index 00000000..0cb78c74 --- /dev/null +++ b/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi @@ -0,0 +1,148 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401 + +import builtins +import collections.abc + +class ConnectionId: + r""" + TODO: documentation... + """ + @staticmethod + def new_unchecked(id:builtins.int) -> ConnectionId: + r""" + TODO: documentation + """ + def __repr__(self) -> builtins.str: ... + def __str__(self) -> builtins.str: ... 
+ +class ConnectionUpdate: + @property + def peer_id(self) -> PeerId: + r""" + Identity of the peer that we have connected to. + """ + @property + def connection_id(self) -> ConnectionId: + r""" + Identifier of the connection. + """ + @property + def local_addr(self) -> Multiaddr: + r""" + Local connection address. + """ + @property + def send_back_addr(self) -> Multiaddr: + r""" + Address used to send back data to the remote. + """ + +class DiscoveryService: + def __new__(cls, identity:Keypair) -> DiscoveryService: ... + def add_connected_callback(self, callback:collections.abc.Callable[[ConnectionUpdate], None]) -> None: ... + def add_disconnected_callback(self, callback:collections.abc.Callable[[ConnectionUpdate], None]) -> None: ... + +class Keypair: + r""" + TODO: documentation... + """ + @staticmethod + def generate_ed25519() -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def generate_ecdsa() -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def generate_secp256k1() -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def from_protobuf_encoding(bytes:bytes) -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def rsa_from_pkcs8(bytes:bytes) -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def secp256k1_from_der(bytes:bytes) -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def ed25519_from_bytes(bytes:bytes) -> Keypair: + r""" + TODO: documentation + """ + @staticmethod + def ecdsa_from_bytes(bytes:bytes) -> Keypair: + r""" + TODO: documentation + """ + def to_protobuf_encoding(self) -> bytes: + r""" + TODO: documentation + """ + +class Multiaddr: + r""" + TODO: documentation... + """ + @staticmethod + def empty() -> Multiaddr: + r""" + TODO: documentation + """ + @staticmethod + def with_capacity(n:builtins.int) -> Multiaddr: + r""" + TODO: documentation + """ + def len(self) -> builtins.int: + r""" + TODO: documentation + """ + def is_empty(self) -> builtins.bool: + r""" + TODO: documentation + """ + def to_bytes(self) -> bytes: + r""" + TODO: documentation + """ + def __repr__(self) -> builtins.str: ... + def __str__(self) -> builtins.str: ... + +class PeerId: + r""" + TODO: documentation... + """ + @staticmethod + def random() -> PeerId: + r""" + TODO: documentation + """ + @staticmethod + def from_bytes(bytes:bytes) -> PeerId: + r""" + TODO: documentation + """ + def to_bytes(self) -> bytes: + r""" + TODO: documentation + """ + def to_base58(self) -> builtins.str: + r""" + TODO: documentation + """ + def __repr__(self) -> builtins.str: ... + def __str__(self) -> builtins.str: ... 
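+
+# Example usage (a minimal sketch; assumes the extension module has been built and is
+# importable as `exo_pyo3_bindings`):
+#
+#   from exo_pyo3_bindings import ConnectionUpdate, DiscoveryService, Keypair
+#
+#   def on_connect(update: ConnectionUpdate) -> None:
+#       print(f"connected to {update.peer_id} via {update.send_back_addr}")
+#
+#   service = DiscoveryService(Keypair.generate_ed25519())
+#   service.add_connected_callback(on_connect)
+#
+# The service runs on a background runtime; keep the process alive to receive callbacks.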
+ diff --git a/rust/exo_pyo3_bindings/pyproject.toml b/rust/exo_pyo3_bindings/pyproject.toml new file mode 100644 index 00000000..1adf83a1 --- /dev/null +++ b/rust/exo_pyo3_bindings/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "exo_pyo3_bindings" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +authors = [ + { name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" } +] +requires-python = ">=3.13" +dependencies = [] + +[dependency-groups] +dev = [ + "exo_pyo3_bindings", + "pytest>=8.4.0", + "pytest-asyncio>=1.0.0", +] + +#[project.scripts] +#networking = "rust-bindings:main" + +[tool.maturin] +#purelib = true +#python-source = "python" +module-name = "exo_pyo3_bindings" +features = ["pyo3/extension-module", "pyo3/experimental-async"] + +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" +asyncio_mode = "auto" \ No newline at end of file diff --git a/rust/exo_pyo3_bindings/src/bin/stub_gen.rs b/rust/exo_pyo3_bindings/src/bin/stub_gen.rs new file mode 100644 index 00000000..ac979ea5 --- /dev/null +++ b/rust/exo_pyo3_bindings/src/bin/stub_gen.rs @@ -0,0 +1,32 @@ +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + let body = async { + env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")) + .init(); + let stub = exo_pyo3_bindings::stub_info()?; + stub.generate()?; + Ok(()) + }; + #[allow( + clippy::expect_used, + clippy::diverging_sub_expression, + clippy::needless_return + )] + { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Failed building the Runtime"); + + let a = runtime.handle(); + + return runtime.block_on(body); + } +} + +// fn main() -> Result<()> { +// let stub = python_bindings::stub_info()?; +// stub.generate()?; +// Ok(()) +// } diff --git a/rust/exo_pyo3_bindings/src/discovery.rs b/rust/exo_pyo3_bindings/src/discovery.rs new file mode 100644 index 00000000..fc3dfa6c --- /dev/null +++ b/rust/exo_pyo3_bindings/src/discovery.rs @@ -0,0 +1,353 @@ +#![allow( + clippy::multiple_inherent_impl, + clippy::unnecessary_wraps, + clippy::unused_self, + clippy::needless_pass_by_value +)] + +use crate::ext::ResultExt; +use crate::pylibp2p::connection::PyConnectionId; +use crate::pylibp2p::ident::{PyKeypair, PyPeerId}; +use crate::pylibp2p::multiaddr::PyMultiaddr; +use crate::{alias, pyclass, MPSC_CHANNEL_SIZE}; +use discovery::behaviour::{DiscoveryBehaviour, DiscoveryBehaviourEvent}; +use discovery::discovery_swarm; +use libp2p::core::ConnectedPoint; +use libp2p::futures::StreamExt; +use libp2p::multiaddr::multiaddr; +use libp2p::swarm::dial_opts::DialOpts; +use libp2p::swarm::{ConnectionId, SwarmEvent, ToSwarm}; +use libp2p::{gossipsub, mdns, Multiaddr, PeerId, Swarm}; +use pyo3::prelude::{PyModule, PyModuleMethods as _}; +use pyo3::{pymethods, Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use std::convert::identity; +use std::error::Error; +use tokio::sync::mpsc; + +struct ConnectionUpdate { + /// Identity of the peer that we have connected to. + peer_id: PeerId, + /// Identifier of the connection. + connection_id: ConnectionId, + /// Local connection address. + local_addr: Multiaddr, + /// Address used to send back data to the remote. 
+    send_back_addr: Multiaddr,
+}
+
+#[gen_stub_pyclass]
+#[pyclass(frozen, name = "ConnectionUpdate")]
+#[derive(Debug, Clone)]
+struct PyConnectionUpdate {
+    /// Identity of the peer that we have connected to.
+    #[pyo3(get)]
+    peer_id: PyPeerId,
+    /// Identifier of the connection.
+    #[pyo3(get)]
+    connection_id: PyConnectionId,
+    /// Local connection address.
+    #[pyo3(get)]
+    local_addr: PyMultiaddr,
+    /// Address used to send back data to the remote.
+    #[pyo3(get)]
+    send_back_addr: PyMultiaddr,
+}
+
+impl PyConnectionUpdate {
+    fn from_connection_event(
+        ConnectionUpdate {
+            peer_id,
+            connection_id,
+            local_addr,
+            send_back_addr,
+        }: ConnectionUpdate,
+    ) -> Self {
+        Self {
+            peer_id: PyPeerId(peer_id),
+            connection_id: PyConnectionId(connection_id),
+            local_addr: PyMultiaddr(local_addr),
+            send_back_addr: PyMultiaddr(send_back_addr),
+        }
+    }
+}
+
+enum IncomingDiscoveryMessage {
+    AddConnectedCallback(Box<dyn Fn(ConnectionUpdate) + Send + 'static>),
+    AddDisconnectedCallback(Box<dyn Fn(ConnectionUpdate) + Send + 'static>),
+}
+
+#[allow(clippy::enum_glob_use)]
+async fn discovery_task(
+    mut receiver: mpsc::Receiver<IncomingDiscoveryMessage>,
+    mut swarm: Swarm<DiscoveryBehaviour>,
+) {
+    use DiscoveryBehaviourEvent::*;
+    use IncomingDiscoveryMessage::*;
+    use SwarmEvent::*;
+    use gossipsub::Event::*;
+    use mdns::Event::*;
+
+    log::info!("RUST: discovery task started");
+
+    // create callbacks list
+    let mut connected_callbacks: Vec<Box<dyn Fn(ConnectionUpdate) + Send + 'static>> = vec![];
+    let mut disconnected_callbacks: Vec<Box<dyn Fn(ConnectionUpdate) + Send + 'static>> = vec![];
+
+    loop {
+        tokio::select! {
+            message = receiver.recv() => {
+                // handle closed channel
+                let Some(message) = message else {
+                    log::info!("RUST: channel closed");
+                    break;
+                };
+
+                // attach callbacks for event types
+                match message {
+                    AddConnectedCallback(callback) => {
+                        log::info!("RUST: received connected callback");
+                        connected_callbacks.push(callback);
+                    }
+                    AddDisconnectedCallback(callback) => {
+                        log::info!("RUST: received disconnected callback");
+                        disconnected_callbacks.push(callback);
+                    }
+                }
+            }
+            swarm_event = swarm.select_next_some() => {
+                match swarm_event {
+                    Behaviour(Mdns(Discovered(list))) => {
+                        for (peer_id, multiaddr) in list {
+                            log::info!("RUST: mDNS discovered a new peer: {peer_id} on {multiaddr}");
+                            // TODO: this does the job of (actually) creating & maintaining connection
+                            //       but it's coupled to gossipsub & also the connection isn't configured
+                            //       for setting "connection keep alive" in NetworkBehavior's ConnectionHandler
+                            //       >in future, make own small NetworkBehavior impl just to track this state
+                            swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
+                        }
+                    }
+                    Behaviour(Mdns(Expired(list))) => {
+                        for (peer_id, multiaddr) in list {
+                            log::info!("RUST: mDNS discovered peer has expired: {peer_id} on {multiaddr}");
+                            swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
+                        }
+                    },
+                    Behaviour(Gossipsub(Message {
+                        propagation_source: peer_id,
+                        message_id: id,
+                        message,
+                    })) => log::info!(
+                        "RUST: Got message: '{}' with id: {id} from peer: {peer_id}",
+                        String::from_utf8_lossy(&message.data),
+                    ),
+                    ConnectionEstablished {
+                        peer_id,
+                        connection_id,
+                        endpoint,
+                        num_established: _num_established,
+                        concurrent_dial_errors,
+                        established_in: _established_in,
+                    } => {
+                        // log any connection errors
+                        if let Some(concurrent_dial_errors) = concurrent_dial_errors {
+                            for (multiaddr, error) in concurrent_dial_errors {
+                                log::error!("Connection error: multiaddr={multiaddr}, error={error:?}");
+                            }
+                        }
+
+                        // TODO: right now we assume we are using TCP/IP, which treats all nodes
+                        //       as both dialers AND listeners. This means for each connection you will actually
+                        //       see TWO duplicate Connected events => Dialer & Listener
+                        //       SO ignore the Dialer & extract the info we need from Listener
+                        //       HOWEVER this makes the swarm implicitly rely on TCP/IP, so it is brittle to changes,
+                        //       e.g. adding the QUIC protocol or something
+                        //       >As soon as we add anything other than TCP/IP, this must be updated or there will be broken code
+                        let ConnectedPoint::Listener { local_addr, send_back_addr } = endpoint else {
+                            log::warn!("Ignoring `ConnectedPoint::Dialer` event because for TCP/IP it has a dual `ConnectedPoint::Listener` event: {endpoint:?}");
+                            continue;
+                        };
+
+                        // trigger callbacks for the connected peer
+                        for connected_callback in &connected_callbacks {
+                            connected_callback(ConnectionUpdate {
+                                peer_id,
+                                connection_id,
+                                local_addr: local_addr.clone(),
+                                send_back_addr: send_back_addr.clone(),
+                            });
+                        }
+                    },
+                    ConnectionClosed { peer_id, connection_id, endpoint, num_established, cause } => {
+                        // log any connection errors
+                        if let Some(cause) = cause {
+                            log::error!("Connection error: cause={cause:?}");
+                        }
+
+                        // TODO: right now we assume we are using TCP/IP, which treats all nodes
+                        //       as both dialers AND listeners. This means for each connection you will actually
+                        //       see TWO duplicate Connected events => Dialer & Listener
+                        //       SO ignore the Dialer & extract the info we need from Listener
+                        //       HOWEVER this makes the swarm implicitly rely on TCP/IP, so it is brittle to changes,
+                        //       e.g. adding the QUIC protocol or something
+                        //       >As soon as we add anything other than TCP/IP, this must be updated or there will be broken code
+                        let ConnectedPoint::Listener { local_addr, send_back_addr } = endpoint else {
+                            log::warn!("Ignoring `ConnectedPoint::Dialer` event because for TCP/IP it has a dual `ConnectedPoint::Listener` event: {endpoint:?}");
+                            continue;
+                        };
+
+                        // trigger callbacks for the disconnected peer
+                        for disconnected_callback in &disconnected_callbacks {
+                            disconnected_callback(ConnectionUpdate {
+                                peer_id,
+                                connection_id,
+                                local_addr: local_addr.clone(),
+                                send_back_addr: send_back_addr.clone(),
+                            });
+                        }
+                    }
+                    e => {
+                        log::info!("RUST: Other event {e:?}");
+                    }
+                }
+            }
+        }
+    }
+
+    log::info!("RUST: discovery task stopped");
+}
+
+#[gen_stub_pyclass]
+#[pyclass(name = "DiscoveryService")]
+#[derive(Debug, Clone)]
+struct PyDiscoveryService {
+    sender: Option<mpsc::Sender<IncomingDiscoveryMessage>>,
+}
+
+#[allow(clippy::expect_used)]
+impl PyDiscoveryService {
+    const fn sender(&self) -> &mpsc::Sender<IncomingDiscoveryMessage> {
+        self.sender
+            .as_ref()
+            .expect("The sender should only be None after de-initialization.")
+    }
+
+    const fn sender_mut(&mut self) -> &mut mpsc::Sender<IncomingDiscoveryMessage> {
+        self.sender
+            .as_mut()
+            .expect("The sender should only be None after de-initialization.")
+    }
+
+    const fn new(sender: mpsc::Sender<IncomingDiscoveryMessage>) -> Self {
+        Self {
+            sender: Some(sender),
+        }
+    }
+}
+
+#[gen_stub_pymethods]
+#[pymethods]
+impl PyDiscoveryService {
+    #[new]
+    fn py_new<'py>(identity: Bound<'py, PyKeypair>) -> PyResult<Self> {
+        use pyo3_async_runtimes::tokio::get_runtime;
+
+        // create communication channel
+        let (sender, receiver) = mpsc::channel::<IncomingDiscoveryMessage>(MPSC_CHANNEL_SIZE);
+
+        // get identity
+        let identity = identity.borrow().0.clone();
+
+        // create discovery swarm (within tokio context!!
or it crashes) + let swarm = get_runtime() + .block_on(async { discovery_swarm(identity) }) + .pyerr()?; + + // spawn tokio task + get_runtime().spawn(async move { + discovery_task(receiver, swarm).await; + }); + Ok(Self::new(sender)) + } + + #[allow(clippy::expect_used)] + fn add_connected_callback<'py>( + &self, + #[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))] + callback: PyObject, + ) -> PyResult<()> { + use pyo3_async_runtimes::tokio::get_runtime; + + get_runtime() + .block_on( + self.sender() + .send(IncomingDiscoveryMessage::AddConnectedCallback(Box::new( + move |connection_event| { + Python::with_gil(|py| { + callback + .call1( + py, + (PyConnectionUpdate::from_connection_event( + connection_event, + ),), + ) + .expect("Callback should always work..."); + }); + }, + ))), + ) + .pyerr()?; + Ok(()) + } + + #[allow(clippy::expect_used)] + fn add_disconnected_callback<'py>( + &self, + #[override_type(type_repr="collections.abc.Callable[[ConnectionUpdate], None]", imports=("collections.abc"))] + callback: PyObject, + ) -> PyResult<()> { + use pyo3_async_runtimes::tokio::get_runtime; + + get_runtime() + .block_on( + self.sender() + .send(IncomingDiscoveryMessage::AddDisconnectedCallback(Box::new( + move |connection_event| { + Python::with_gil(|py| { + callback + .call1( + py, + (PyConnectionUpdate::from_connection_event( + connection_event, + ),), + ) + .expect("Callback should always work..."); + }); + }, + ))), + ) + .pyerr()?; + Ok(()) + } + + #[gen_stub(skip)] + const fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + Ok(()) // This is needed purely so `__clear__` can work + } + + #[gen_stub(skip)] + fn __clear__(&mut self) { + // TODO: may or may not need to await a "kill-signal" oneshot channel message, + // to ensure that the discovery task is done BEFORE exiting the clear function... + // but this may require GIL?? and it may not be safe to call GIL here?? + self.sender = None; // Using Option as a trick to force `sender` channel to be dropped + } +} + +pub fn discovery_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/src/lib.rs b/rust/exo_pyo3_bindings/src/lib.rs new file mode 100644 index 00000000..f1eed2c7 --- /dev/null +++ b/rust/exo_pyo3_bindings/src/lib.rs @@ -0,0 +1,101 @@ +//! TODO: crate documentation +//! +//! this is here as a placeholder documentation +//! +//! + +// enable Rust-unstable features for convenience +#![feature(trait_alias)] +#![feature(tuple_trait)] +#![feature(unboxed_closures)] +// #![feature(stmt_expr_attributes)] +// #![feature(assert_matches)] +// #![feature(async_fn_in_dyn_trait)] +// #![feature(async_for_loop)] +// #![feature(auto_traits)] +// #![feature(negative_impls)] + +extern crate core; +pub(crate) mod discovery; +pub(crate) mod pylibp2p; + +use crate::discovery::discovery_submodule; +use crate::pylibp2p::connection::connection_submodule; +use crate::pylibp2p::ident::ident_submodule; +use crate::pylibp2p::multiaddr::multiaddr_submodule; +use pyo3::prelude::{PyModule, PyModuleMethods}; +use pyo3::{prelude::*, types::*}; +use pyo3::{pyclass, pymodule, Bound, PyResult}; +use pyo3_stub_gen::define_stub_info_gatherer; + +/// Namespace for all the type/trait aliases used by this crate. 
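+///
+/// For example, the discovery task keeps its Python-facing callbacks as
+/// `Vec<Box<dyn SendFn<(ConnectionUpdate,), ()>>>`: boxed `Send + 'static`
+/// closures taking a single `ConnectionUpdate` and returning nothing.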
+pub(crate) mod alias {
+    use std::error::Error;
+    use std::marker::Tuple;
+
+    pub trait SendFn<Args: Tuple, Out> = Fn<Args, Output = Out> + Send + 'static;
+
+    pub type AnyError = Box<dyn Error + Send + Sync>;
+    pub type AnyResult<T> = Result<T, AnyError>;
+}
+
+/// Namespace for crate-wide extension traits/methods
+pub(crate) mod ext {
+    use extend::ext;
+    use pyo3::exceptions::PyRuntimeError;
+    use pyo3::PyErr;
+
+    #[ext(pub, name = ResultExt)]
+    impl<T, E> Result<T, E>
+    where
+        E: ToString,
+    {
+        fn pyerr(self) -> Result<T, PyErr> {
+            self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
+        }
+    }
+}
+
+pub(crate) mod private {
+    use std::marker::Sized;
+
+    /// Sealed traits support
+    pub trait Sealed {}
+    impl<T: ?Sized> Sealed for T {}
+}
+
+pub(crate) const MPSC_CHANNEL_SIZE: usize = 8;
+
+/// A Python module implemented in Rust. The name of this function must match
+/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
+/// import the module.
+#[pymodule(name = "exo_pyo3_bindings")]
+fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    // install logger
+    pyo3_log::init();
+
+    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
+    //       work with maturin, where the types generate correctly, in the right folder, without
+    //       too many importing issues...
+    connection_submodule(m)?;
+    ident_submodule(m)?;
+    multiaddr_submodule(m)?;
+    discovery_submodule(m)?;
+
+    // top-level constructs
+    // TODO: ...
+
+    Ok(())
+}
+
+define_stub_info_gatherer!(stub_info);
+
+/// Unit test for testing the linking problem
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn test() {
+        assert_eq!(2 + 2, 4);
+    }
+}
diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/connection.rs b/rust/exo_pyo3_bindings/src/pylibp2p/connection.rs
new file mode 100644
index 00000000..ac6c0125
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/pylibp2p/connection.rs
@@ -0,0 +1,36 @@
+use libp2p::swarm::ConnectionId;
+use pyo3::prelude::{PyModule, PyModuleMethods};
+use pyo3::{pyclass, pymethods, Bound, PyResult};
+use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
+
+/// TODO: documentation...
+#[gen_stub_pyclass]
+#[pyclass(name = "ConnectionId")]
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub struct PyConnectionId(pub ConnectionId);
+
+#[gen_stub_pymethods]
+#[pymethods]
+#[allow(clippy::needless_pass_by_value)]
+impl PyConnectionId {
+    /// TODO: documentation
+    #[staticmethod]
+    fn new_unchecked(id: usize) -> Self {
+        Self(ConnectionId::new_unchecked(id))
+    }
+
+    fn __repr__(&self) -> String {
+        format!("ConnectionId({})", self.0)
+    }
+
+    fn __str__(&self) -> String {
+        self.0.to_string()
+    }
+}
+
+pub fn connection_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_class::<PyConnectionId>()?;
+
+    Ok(())
+}
diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs b/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
new file mode 100644
index 00000000..73239cca
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
@@ -0,0 +1,130 @@
+use crate::ext::ResultExt;
+use libp2p::identity::{ecdsa, Keypair};
+use libp2p::PeerId;
+use pyo3::prelude::{PyBytesMethods, PyModule, PyModuleMethods};
+use pyo3::types::PyBytes;
+use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
+
+/// TODO: documentation...
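+///
+/// Until proper docs land, a minimal usage sketch from the Python side
+/// (method names as exposed by the generated stubs):
+///
+/// ```python
+/// kp = Keypair.generate_ed25519()
+/// raw = kp.to_protobuf_encoding()
+/// assert Keypair.from_protobuf_encoding(raw).to_protobuf_encoding() == raw
+/// ```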
+#[gen_stub_pyclass] +#[pyclass(name = "Keypair")] +#[repr(transparent)] +pub struct PyKeypair(pub Keypair); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyKeypair { + /// TODO: documentation + #[staticmethod] + fn generate_ed25519() -> Self { + Self(Keypair::generate_ed25519()) + } + + /// TODO: documentation + #[staticmethod] + fn generate_ecdsa() -> Self { + Self(Keypair::generate_ecdsa()) + } + + /// TODO: documentation + #[staticmethod] + fn generate_secp256k1() -> Self { + Self(Keypair::generate_secp256k1()) + } + + /// TODO: documentation + #[staticmethod] + fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?)) + } + + /// TODO: documentation + #[staticmethod] + fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?)) + } + + /// TODO: documentation + #[staticmethod] + fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?)) + } + + /// TODO: documentation + #[staticmethod] + fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?)) + } + + /// TODO: documentation + #[staticmethod] + fn ecdsa_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::from(ecdsa::Keypair::from( + ecdsa::SecretKey::try_from_bytes(bytes).pyerr()?, + )))) + } + + /// TODO: documentation + fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult> { + let bytes = self.0.to_protobuf_encoding().pyerr()?; + Ok(PyBytes::new(py, &bytes)) + } +} + +/// TODO: documentation... 
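+///
+/// Until proper docs land, a minimal usage sketch from the Python side:
+///
+/// ```python
+/// pid = PeerId.random()
+/// round_tripped = PeerId.from_bytes(pid.to_bytes())
+/// assert round_tripped.to_base58() == pid.to_base58()
+/// ```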
+#[gen_stub_pyclass] +#[pyclass(name = "PeerId")] +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyPeerId(pub PeerId); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyPeerId { + /// TODO: documentation + #[staticmethod] + fn random() -> Self { + Self(PeerId::random()) + } + + /// TODO: documentation + #[staticmethod] + fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(PeerId::from_bytes(&bytes).pyerr()?)) + } + + /// TODO: documentation + fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + let bytes = self.0.to_bytes(); + PyBytes::new(py, &bytes) + } + + /// TODO: documentation + fn to_base58(&self) -> String { + self.0.to_base58() + } + + fn __repr__(&self) -> String { + format!("PeerId({})", self.to_base58()) + } + + fn __str__(&self) -> String { + self.to_base58() + } +} + +pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs b/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs new file mode 100644 index 00000000..ba8e358d --- /dev/null +++ b/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs @@ -0,0 +1,3 @@ +pub mod connection; +pub mod ident; +pub mod multiaddr; diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs b/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs new file mode 100644 index 00000000..38f555f4 --- /dev/null +++ b/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs @@ -0,0 +1,59 @@ +use libp2p::Multiaddr; +use pyo3::prelude::{PyModule, PyModuleMethods}; +use pyo3::types::PyBytes; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +/// TODO: documentation... +#[gen_stub_pyclass] +#[pyclass(name = "Multiaddr")] +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyMultiaddr(pub Multiaddr); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyMultiaddr { + /// TODO: documentation + #[staticmethod] + fn empty() -> Self { + Self(Multiaddr::empty()) + } + + /// TODO: documentation + #[staticmethod] + fn with_capacity(n: usize) -> Self { + Self(Multiaddr::with_capacity(n)) + } + + /// TODO: documentation + fn len(&self) -> usize { + self.0.len() + } + + /// TODO: documentation + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// TODO: documentation + fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + let bytes = self.0.to_vec(); + PyBytes::new(py, &bytes) + } + + fn __repr__(&self) -> String { + format!("Multiaddr({})", self.0) + } + + fn __str__(&self) -> String { + self.0.to_string() + } +} + +pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/tests/dummy.rs b/rust/exo_pyo3_bindings/tests/dummy.rs new file mode 100644 index 00000000..7d1ce0e4 --- /dev/null +++ b/rust/exo_pyo3_bindings/tests/dummy.rs @@ -0,0 +1,54 @@ +#[cfg(test)] +mod tests { + use core::mem::drop; + use core::option::Option::Some; + use core::time::Duration; + use tokio; + use tokio::sync::mpsc; + + #[tokio::test] + async fn test_drop_channel() { + struct Ping; + + let (tx, mut rx) = mpsc::channel::(10); + + let _ = tokio::spawn(async move { + println!("TASK: entered"); + + loop { + tokio::select! 
{ + result = rx.recv() => { + match result { + Some(_) => { + println!("TASK: pinged"); + } + None => { + println!("TASK: closing channel"); + break; + } + } + } + _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => { + println!("TASK: heartbeat"); + } + } + } + + println!("TASK: exited"); + }); + + let tx2 = tx.clone(); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + + tx.send(Ping).await.expect("Should not fail"); + drop(tx); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + + tx2.send(Ping).await.expect("Should not fail"); + drop(tx2); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + } +} diff --git a/rust/exo_pyo3_bindings/tests/test_python.py b/rust/exo_pyo3_bindings/tests/test_python.py new file mode 100644 index 00000000..d1408f45 --- /dev/null +++ b/rust/exo_pyo3_bindings/tests/test_python.py @@ -0,0 +1,72 @@ +import time +from collections.abc import Awaitable +from typing import Callable + +import pytest +from exo_pyo3_bindings import ConnectionUpdate, Keypair, DiscoveryService + + +# # => `tokio::mpsc` channels are closed when all `Sender` are dropped, or when `Receiver::close` is called +# # => the only sender is `KillableTaskHandle.sender: Option>>` +# # => integrate with https://pyo3.rs/v0.25.1/class/protocols.html#garbage-collector-integration +# # => set `sender` to `None` to drop the `Sender` & therefore trigger an automatic cleanup +# # => TODO: there could be a bug where dropping `Sender` won't close the channel in time bc of unprocessed events +# # so the handle drops and asyncio loop closes BEFORE the task dies... +# # might wanna figure out some kind of `oneshot` "shutdown confirmed" blocking mechanism or something...?? +# # => also there is "cancellable futures" stuff ?? => https://pyo3.rs/main/async-await.html +# # +# # For now, always explicitly call cleanup functions to avoid crashes +# # in the future research tighter integration for automatic cleanup and safety!!! +# # also look into `pyo3_async_runtimes::tokio::get_runtime()` for blocking calls in Rust +# @pytest.mark.asyncio +# async def test_handle_kill() -> None: +# print("PYTHON: starting handle") +# h: KillableTaskHandle = killable_task_spawn() + +# time.sleep(0.35) + +# # for i in range(0, 4): +# # print(f"PYTHON: waiting... 
{i}") +# # time.sleep(0.11) + +# # print("PYTHON: killing task") +# # h.kill_task() + +# def test_keypair_creation() -> None: +# kp = Keypair.generate_ecdsa() +# kp_protobuf = kp.to_protobuf_encoding() +# print(kp_protobuf) +# kp = Keypair.from_protobuf_encoding(kp_protobuf) +# assert kp.to_protobuf_encoding() == kp_protobuf + + +@pytest.mark.asyncio +async def test_discovery_callbacks() -> None: + ident = Keypair.generate_ed25519() + + service = DiscoveryService(ident) + service.add_connected_callback(add_connected_callback) + service.add_disconnected_callback(disconnected_callback) + + for i in range(0, 10): + print(f"PYTHON: tick {i} of 10") + time.sleep(1) + + pass + + +def add_connected_callback(e: ConnectionUpdate) -> None: + print(f"\n\nPYTHON: Connected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}") + print( + f"PYTHON: Connected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n") + + +def disconnected_callback(e: ConnectionUpdate) -> None: + print(f"\n\nPYTHON: Disconnected callback: {e.peer_id}, {e.connection_id}, {e.local_addr}, {e.send_back_addr}") + print( + f"PYTHON: Disconnected callback: {e.peer_id.__repr__()}, {e.connection_id.__repr__()}, {e.local_addr.__repr__()}, {e.send_back_addr.__repr__()}\n\n") + + +async def foobar(a: Callable[[str], Awaitable[str]]): + abc = await a("") + pass diff --git a/rust/master_election/Cargo.toml b/rust/master_election/Cargo.toml new file mode 100644 index 00000000..c5164f50 --- /dev/null +++ b/rust/master_election/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "master_election" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +doctest = false +name = "master_election" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +# macro dependencies +extend = { workspace = true } +delegate = { workspace = true } +impl-trait-for-tuples = { workspace = true } +derive_more = { workspace = true } + +# Async +tokio = { workspace = true, features = ["full"] } +futures = { workspace = true } + +# utility dependencies +#util = { workspace = true } +#fn_pipe = { workspace = true } +thiserror = { workspace = true } +#internment = { workspace = true } +#recursion = { workspace = true } +#generativity = { workspace = true } +#itertools = { workspace = true } +tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] } +keccak-const = { workspace = true } + +# Data types +ordered-float = { workspace = true } + +# Networking +libp2p = { workspace = true, features = ["full"] } \ No newline at end of file diff --git a/rust/master_election/src/cel/centrality.rs b/rust/master_election/src/cel/centrality.rs new file mode 100644 index 00000000..2042d384 --- /dev/null +++ b/rust/master_election/src/cel/centrality.rs @@ -0,0 +1,36 @@ +use crate::cel::data::Map; +use crate::cel::{View, ID}; + +/// The number of neighbours of a process. +pub fn degree_centrality(known: &Map, id: ID) -> u32 { + todo!() +} + +/// Measures average length of the shortest path between the vertex and all other vertices in the graph. +/// The more central is a vertex, the closer it is to all other vertices. The closeness centrality +/// characterizes the ability of a node to spread information over the graph. +/// +/// Alex Balevas defined in 1950 the closeness centrality of a vertex as follows: +/// `C_C(x) = \frac{1}{ \sum_y d(x,y) }` where `d(x,y)` is the shortest path between `x` and `y`. 
+///
+/// The CEL paper uses this.
+pub fn closeness_centrality(known: &Map<ID, View>, id: ID) -> u32 {
+    todo!()
+}
+
+/// Measures the number of times a vertex acts as a relay (router) along
+/// shortest paths between other vertices. Even if previous authors
+/// have intuitively described centrality as being based on betweenness,
+/// betweenness centrality was formally defined by Freeman in 1977.
+///
+/// The betweenness of a vertex `x` is defined as the sum, for each pair
+/// of vertices `(s, t)`, of the number of shortest paths from `s` to `t` that
+/// pass through `x`, over the total number of shortest paths between
+/// vertices `s` and `t`; it can be represented by the following formula:
+/// `C_B(x) = \sum_{ s \neq x \neq t } \frac{ \sigma_{st}(x) }{ \sigma_{st} }`
+/// where `\sigma_{st}` denotes the total number of shortest paths from vertex `s`
+/// to vertex `t` (with `\sigma_{ss} = 1` by convention), and `\sigma_{st}(x)`
+/// is the number of those shortest paths that pass through `x`.
+pub fn betweenness_centrality(known: &Map<ID, View>, id: ID) -> u32 {
+    todo!()
+}
diff --git a/rust/master_election/src/cel/messaging.rs b/rust/master_election/src/cel/messaging.rs
new file mode 100644
index 00000000..4cac6dd1
--- /dev/null
+++ b/rust/master_election/src/cel/messaging.rs
@@ -0,0 +1,57 @@
+use crate::cel::messaging::data::Probability;
+use crate::cel::KnowledgeMessage;
+
+mod data {
+    use ordered_float::OrderedFloat;
+    use thiserror::Error;
+
+    #[derive(Error, Debug, Copy, Clone, PartialEq, PartialOrd)]
+    #[error("Floating-point number `{0}` is not a probability")]
+    #[repr(transparent)]
+    pub struct NotProbabilityError(f64);
+
+    #[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+    #[repr(transparent)]
+    pub struct Probability(OrderedFloat<f64>);
+
+    impl Probability {
+        const MIN_P: OrderedFloat<f64> = OrderedFloat(0.0);
+        const MAX_P: OrderedFloat<f64> = OrderedFloat(1.0);
+
+        pub fn new(p: f64) -> Result<Self, NotProbabilityError> {
+            let p = OrderedFloat(p);
+            if Self::MIN_P <= p && p <= Self::MAX_P {
+                Ok(Self(p))
+            } else {
+                Err(NotProbabilityError(p.0))
+            }
+        }
+
+        pub const fn into_f64(self) -> f64 {
+            self.0.0
+        }
+    }
+
+    impl From<Probability> for f64 {
+        fn from(value: Probability) -> Self {
+            value.into_f64()
+        }
+    }
+
+    impl TryFrom<f64> for Probability {
+        type Error = NotProbabilityError;
+        fn try_from(value: f64) -> Result<Self, Self::Error> {
+            Self::new(value)
+        }
+    }
+}
+
+/// Haas et al. proposed several gossip protocols for *ad hoc networks* that use probabilities.
+/// Combined with the number of hops or the number of times the same message is received, the
+/// protocols choose whether a node broadcasts a message to all its neighbors or not, thus reducing
+/// the number of messages propagated in the system. The authors show that gossiping with a
+/// probability between 0.6 and 0.8 ensures that almost every node of the system gets the message,
+/// with up to 35% fewer messages in some networks compared to flooding.
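+///
+/// The body is still a stub; a sketch of the intended call shape, assuming
+/// callers construct `rho` via `Probability::new`:
+///
+/// ```ignore
+/// let rho = Probability::new(0.7).expect("0.7 is within [0, 1]");
+/// local_broadcast(message, rho);
+/// ```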
+pub fn local_broadcast(message: KnowledgeMessage, rho: Probability) {
+    //
+}
diff --git a/rust/master_election/src/cel/mod.rs b/rust/master_election/src/cel/mod.rs
new file mode 100644
index 00000000..b7856d28
--- /dev/null
+++ b/rust/master_election/src/cel/mod.rs
@@ -0,0 +1,333 @@
+pub mod centrality;
+pub mod messaging;
+
+use crate::cel::data::{Map, Set};
+use std::collections::VecDeque;
+
+pub mod data {
+    use std::marker::PhantomData;
+
+    #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+    pub struct Set<V>(PhantomData<V>);
+
+    impl<V> Set<V> {
+        pub fn new() -> Self {
+            todo!()
+        }
+
+        pub fn add(&mut self, value: V) -> bool {
+            todo!()
+        }
+
+        pub fn remove(&mut self, value: V) {}
+
+        pub fn add_all(&mut self, other: &Set<V>) {}
+
+        pub fn values_mut(&mut self) -> &mut [V] {
+            todo!()
+        }
+
+        pub fn values(&self) -> &[V] {
+            todo!()
+        }
+    }
+
+    #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+    pub struct Map<K, V>(PhantomData<(K, V)>);
+
+    impl<K, V> Map<K, V> {
+        pub fn new() -> Self {
+            todo!()
+        }
+
+        pub fn set(&mut self, key: K, value: V) {}
+
+        pub fn get(&self, key: K) -> &V {
+            todo!()
+        }
+
+        pub fn get_mut(&mut self, key: K) -> &mut V {
+            todo!()
+        }
+
+        pub fn kv_mut(&mut self) -> &mut [(K, V)] {
+            todo!()
+        }
+
+        pub fn contains_key(&self, key: K) -> bool {
+            todo!()
+        }
+
+        pub fn not_contains_key(&self, key: K) -> bool {
+            !self.contains_key(key)
+        }
+    }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+#[repr(transparent)]
+pub struct ID(pub u128);
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+#[repr(transparent)]
+pub struct Clock(pub u64);
+
+impl Clock {
+    pub const ZERO: Self = Self(0);
+    pub const ONE: Self = Self(1);
+
+    pub fn plus_one(self) -> Self {
+        Self(self.0 + 1)
+    }
+}
+
+/// `CEL` uses a data structure called a `view`.
+///
+/// A `view` associated to a node is composed of two elements:
+/// 1) A logical `clock` value, acting as a timestamp and incremented at each connection and disconnection.
+/// 2) A set of node `identifiers`, which are the current neighbors of `i` (this node).
+#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+pub struct View {
+    /// Logical clock
+    clock: Clock,
+
+    /// Neighbors set
+    neigh: Set<ID>,
+}
+
+impl View {
+    pub fn new(clock: Clock, neigh: Set<ID>) -> Self {
+        Self { clock, neigh }
+    }
+}
+
+/// The only type of message exchanged between neighbors is the `knowledge` message.
+/// It contains the current topological knowledge that the sender node has of the network,
+/// i.e. its `known` variable.
+#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+pub struct KnowledgeMessage {
+    pub known: Map<ID, View>,
+}
+
+/// Each node `i` maintains a local variable called `known`.
+///
+/// This variable represents the current topological knowledge that `i` has of its current
+/// component (including itself). It is implemented as a map of `view`s indexed by node `identifier`.
+#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
+pub struct Node {
+    id: ID,
+    known: Map<ID, View>,
+}
+
+impl Node {
+    /// Firstly, the node initializes its `known` variable with its own identifier (`i`),
+    /// and sets its logical clock to `0`.
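+    ///
+    /// In other words, after `Node::initialization(i)` the only entry is the node's
+    /// own view: `known = { i -> View { clock: 0, neigh: { i } } }`.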
+    pub fn initialization(this_id: ID) -> Self {
+        let mut neigh = Set::new(); // neigh = \{ i \}
+        neigh.add(this_id);
+
+        let mut known = Map::<ID, View>::new();
+        known.set(this_id, View::new(Clock::ZERO, neigh));
+
+        Self { id: this_id, known }
+    }
+
+    /// When a new node `j` appears in the transmission range of `i`, the cross-layer mechanism of
+    /// `i` detects `j`, and triggers the `Connection` method.
+    ///
+    /// Node `j` is added to the neighbors set of node `i`. As the knowledge of `i` has been updated,
+    /// its logical clock is incremented.
+    ///
+    /// Since links are assumed bidirectional, i.e. the emission range equals the reception range,
+    /// if node `i` has no previous knowledge of `j`, the neighbor-aware mechanism adds both
+    /// `i` and `j` to the set of neighbors of `j`. Then, `i` sets the clock value of `j` to `1`,
+    /// as `i` was added to the knowledge of node `j`. On the other hand, if node `i` already has
+    /// information about `j`, `i` is added to the neighbors of `j`, and the logical clock of
+    /// node `j` is incremented.
+    ///
+    /// Finally, by calling the `LocalBroadcast` method, node `i` shares its
+    /// knowledge with `j` and informs its neighborhood of its new neighbor `j`.
+    /// Note that such a method sends a knowledge message to the neighbors
+    /// of node `i`, with a gossip probability `\rho`, as seen in `Section 2.8`.
+    /// However, for the first hop, `\rho` is set to `1` to make sure that all neighbors of `i`
+    /// will be aware of its new neighbor `j`. Note that the cross-layer mechanism
+    /// of node `j` will also trigger its `Connection` method, and the respective
+    /// steps will also be achieved on node `j`.
+    pub fn node_connection(&mut self, other_id: ID) {
+        let this_known = self.known.get_mut(self.id);
+        this_known.neigh.add(other_id);
+        this_known.clock = this_known.clock.plus_one();
+
+        if self.known.not_contains_key(other_id) {
+            let mut other_neigh = Set::new(); // neigh = \{ j, i \}
+            other_neigh.add(self.id);
+            other_neigh.add(other_id);
+
+            self.known.set(other_id, View::new(Clock::ONE, other_neigh));
+        } else {
+            let other_known = self.known.get_mut(other_id);
+            other_known.neigh.add(self.id);
+            other_known.clock = other_known.clock.plus_one();
+        }
+
+        // TODO: `LocalBroadcast(knowledge, 1)`
+    }
+
+    /// When a node `j` disappears from the transmission range of node `i`,
+    /// the cross-layer mechanism stops receiving beacon messages at the
+    /// MAC level, and triggers the `Disconnection` method. Node `j` is
+    /// then removed from the knowledge of node `i`, and its clock
+    /// is incremented as its knowledge was modified.
+    ///
+    /// Then, the neighbor-aware mechanism assumes that node `i` will also disconnect
+    /// from `j`. Therefore, `i` is removed from the neighborhood of `j` in the
+    /// knowledge of node `i`, and the corresponding clock is incremented.
+    ///
+    /// Finally, node `i` broadcasts its updated knowledge to its neighbors.
+    pub fn node_disconnected(&mut self, other_id: ID) {
+        let this_known = self.known.get_mut(self.id);
+        this_known.neigh.remove(other_id);
+        this_known.clock = this_known.clock.plus_one();
+
+        let other_known = self.known.get_mut(other_id);
+        other_known.neigh.remove(self.id);
+        other_known.clock = other_known.clock.plus_one();
+
+        // TODO: `LocalBroadcast(knowledge, 1)`
+    }
+
+    /// When node `i` receives a knowledge message `known_j` from node `j`,
+    /// it looks at each node `n` included in `known_j`. 
If `n` is an
+    /// unknown node for `i`, or if `n` is known by node `i` and has a
+    /// more recent clock value in `known_j`, the clock and neighbors of
+    /// node `n` are updated in the knowledge of `i`.
+    ///
+    /// Note that a clock value of `n` higher than the one currently known by
+    /// node `i` means that node `n` made some connections and/or
+    /// disconnections of which node `i` is not aware. Then, the `UpdateNeighbors`
+    /// method is called to update the knowledge of `i` regarding the neighbors
+    /// of `n`. If the clock value of node `n` is identical in both
+    /// the knowledge of node `i` and `known_j`, the neighbor-aware
+    /// mechanism merges the neighbors of node `n` from `known_j` with the
+    /// known neighbors of `n` in the knowledge of `i`.
+    ///
+    /// Remark that the clock of node `n` is not updated by the neighbor-aware
+    /// mechanism; otherwise, `n` would not be able to override this view in the
+    /// future with more recent information. The `UpdateNeighbors` method is
+    /// then called. Finally, node `i` broadcasts its knowledge only if
+    /// the latter was modified.
+    pub fn receive_knowledge(
+        &mut self,
+        other_id: ID,
+        KnowledgeMessage {
+            known: mut other_known,
+        }: KnowledgeMessage,
+    ) {
+        let mut this_known_updated = false;
+
+        for (n, other_known_n) in other_known.kv_mut() {
+            if self.known.not_contains_key(*n) || other_known_n.clock > self.known.get(*n).clock {
+                self.known.set(*n, other_known_n.clone());
+                // TODO: UpdateNeighbors(known_j, n)
+            } else if other_known_n.clock == self.known.get(*n).clock {
+                self.known.get_mut(*n).neigh.add_all(&other_known_n.neigh);
+                // TODO: UpdateNeighbors(known_j, n)
+            }
+        }
+
+        // TODO: figure out what constitutes "updated", i.e. should any of the two branches count?
+        //       or should each atomic update-op be checked for "change"??
+        if this_known_updated {
+            // TODO: TopologicalBroadcast()
+        }
+    }
+
+    /// The `UpdateNeighbors` method updates the knowledge of `i` with
+    /// information about the neighbors of node `n`. If the neighbor `k`
+    /// is an unknown node for `i`, or if `k` is known by `i` but has a
+    /// more recent clock value in `known_j` (line 38), the clock and neighbors
+    /// of node `k` are added or updated in the knowledge of node `i`.
+    /// Otherwise, if the clock of node `k` is identical in the knowledge of node
+    /// `i` and in `known_j`, the neighbor-aware mechanism merges the
+    /// neighbors of node `k` in the knowledge of `i`.
+    fn update_neighbors(&mut self, other_known: &mut Map<ID, View>, n: ID) {
+        for k in other_known.get(n).neigh.values() {
+            if self.known.not_contains_key(*k)
+                || other_known.get(*k).clock > self.known.get(*k).clock
+            {
+                self.known.set(*k, other_known.get(*k).clone());
+            } else if other_known.get(*k).clock == self.known.get(*k).clock {
+                self.known
+                    .get_mut(*k)
+                    .neigh
+                    .add_all(&other_known.get(*k).neigh);
+            }
+        }
+    }
+
+    /// The `TopologicalBroadcast` method uses a self-pruning approach to decide whether
+    /// or not to broadcast the updated knowledge of node `i`, after the reception of a `knowledge`
+    /// message from a neighbor `j`. To this end, node `i` checks whether each of its neighbors
+    /// has the same neighborhood as itself. In this case, node `n` is supposed to have
+    /// also received the knowledge message from neighbor node `j`. Therefore, among the
+    /// neighbors having the same neighborhood as `i`, only the one with
+    /// the smallest identifier will broadcast the knowledge, with a
+    /// gossip probability `\rho`. 
Note that this topological self-pruning + /// mechanism reaches the same neighborhood as multiple broadcasts. + fn topological_broadcast(&self) { + for n in self.known.get(self.id).neigh.values() { + // TODO: ensure this is a value-equality comparison + if self.known.get(*n).neigh == self.known.get(self.id).neigh { + if *n < self.id { + return; + } + } + } + + // TODO: `LocalBroadcast(knowlege, \rho)` + } + + /// The leader is elected when a process running on node `i` calls the `Leader` + /// function. This function returns the most central leader in the component + /// according the closeness centrality, as seen in Section 2.7, using the + /// knowledge of node `i`. The closeness centrality is chosen instead of the + /// betweenness centrality, because it is faster to compute and requires fewer + /// computational steps, therefore consuming less energy from the mobile node + /// batteries than the latter. + /// + /// First, node `i` rebuilds its component according to its topological knowledge. + /// To do so, it computes the entire set of reachable nodes, by adding + /// neighbors, neighbors of its neighbors, and so on. + /// Then, it evaluates the shortest distance between each reachable node and the + /// other ones, and computes the closeness centrality for each of them. + /// Finally, it returns the node having the highest closeness value as the + /// leader. The highest node identifier is used to break ties among + /// identical centrality values. If all nodes of the component have the same + /// topological knowledge, the `Leader()` function will return the same leader + /// node when invoked. Otherwise, it may return different leader nodes. + /// However, when the network topology stops changing, the algorithm + /// ensures that all nodes of a component will eventually have the same + /// topological knowledge and therefore, the `Leader()` function will return + /// the same leader node for all of them. + fn leader(&self) -> ID { + // this just computes the transitive closure of the adj-list graph starting from node `i` + // TODO: its an inefficient BFS impl, swap to better later!!! + let mut component = Set::new(); + + let mut process_queue = + VecDeque::from_iter(self.known.get(self.id).neigh.values().iter().cloned()); + while let Some(j) = process_queue.pop_front() { + let successfully_added = component.add(j); + + // was already processed, so don't add neighbors + if !successfully_added { + continue; + } + + process_queue.extend(self.known.get(j).neigh.values().iter().cloned()); + } + + let leader: ID = todo!(); // TODO: `Max (ClosenessCentrality (component))` + return leader; + } +} diff --git a/rust/master_election/src/communicator.rs b/rust/master_election/src/communicator.rs new file mode 100644 index 00000000..7913ad8d --- /dev/null +++ b/rust/master_election/src/communicator.rs @@ -0,0 +1,35 @@ +//! Communicator is an abstraction that allows me to "mock" speaking to the network +//! + +use crate::participant::{Participant, ParticipantId}; +use crate::ElectionMessage; + +pub trait Communicator { + fn all_participants(&self) -> &[ParticipantId]; + fn broadcast_message(&self, message: ElectionMessage, recipients: &[ParticipantId]) -> (); + fn register_participant(&mut self, participant: &Participant) -> ParticipantId; +} + +mod communicator_impls { + macro_rules! 
as_ref_impl { + () => { + #[inline] + fn all_participants(&self) -> &[ParticipantId] { + self.as_ref().all_participants() + } + + #[inline] + fn broadcast_message(&self, message: Message, recipients: &[ParticipantId]) { + self.as_ref().broadcast_message(message, recipients); + } + }; + } + + // impl Communicator for Box { + // as_ref_impl!(); + // } + // + // impl Communicator for Arc { + // as_ref_impl!(); + // } +} diff --git a/rust/master_election/src/lib.rs b/rust/master_election/src/lib.rs new file mode 100644 index 00000000..221f15d8 --- /dev/null +++ b/rust/master_election/src/lib.rs @@ -0,0 +1,44 @@ +//! TODO: crate documentation +//! +//! this is here as a placeholder documentation +//! +//! + +// enable Rust-unstable features for convenience +#![feature(trait_alias)] +// #![feature(stmt_expr_attributes)] +// #![feature(unboxed_closures)] +// #![feature(assert_matches)] +// #![feature(async_fn_in_dyn_trait)] +// #![feature(async_for_loop)] +// #![feature(auto_traits)] +// #![feature(negative_impls)] + +use crate::participant::ParticipantId; + +pub mod cel; +mod communicator; +mod participant; + +/// Namespace for all the type/trait aliases used by this crate. +pub(crate) mod alias {} + +/// Namespace for crate-wide extension traits/methods +pub(crate) mod ext {} + +pub(crate) mod private { + /// Sealed traits support + pub trait Sealed {} + impl Sealed for T {} +} + +pub enum ElectionMessage { + /// Announce election + Election { + candidate: ParticipantId, + }, + Alive, + Victory { + coordinator: ParticipantId, + }, +} diff --git a/rust/master_election/src/participant.rs b/rust/master_election/src/participant.rs new file mode 100644 index 00000000..f027d9e4 --- /dev/null +++ b/rust/master_election/src/participant.rs @@ -0,0 +1,203 @@ +use crate::communicator::Communicator; +use crate::ElectionMessage; +use std::sync::Arc; +use std::time::Duration; +use thiserror::Error; +use tokio::sync::{mpsc, Mutex}; + +// trait ParticipantState {} // TODO: make sealed or something?? +// +// struct Coordinator; // TODO: change to master +// struct Candidate; // i.e. election candidate +// struct Transient; // transient state, e.g. waiting for election results, declaring themselves winner, etc +// struct Follower; // i.e. a follower of an existing coordinator +// +// mod participant_impl { +// use crate::participant::{Candidate, Coordinator, Follower, ParticipantState, Transient}; +// +// impl ParticipantState for Coordinator {} +// impl ParticipantState for Candidate {} +// impl ParticipantState for Transient {} +// impl ParticipantState for Follower {} +// } + +pub type ParticipantSelf = Arc>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct ParticipantId(pub u128); + +#[derive(Debug, Clone, Copy)] +pub enum ParticipantState { + Coordinator, // i.e. 
master
+    ElectionCandidate, // after noticing a master went down, become a candidate and send an `Election` message to all nodes higher than itself
+    Waiting,           // when lower nodes are waiting for the results of an election to conclude
+    Follower { id: ParticipantId }, // when a participant is following a coordinator
+    Transient,         // when the participant is in a neutral/uninitialized state
+}
+
+pub struct Participant {
+    id: ParticipantId,
+    state: ParticipantState,
+    on_message_sent: Vec<Box<dyn FnOnce(ElectionMessage, ParticipantId) + Send + 'static>>,
+}
+
+mod impls {
+    use crate::participant::{Participant, ParticipantId, ParticipantSelf, ParticipantState};
+    use crate::ElectionMessage;
+
+    impl Participant {
+        pub fn new_with(id: ParticipantId, state: ParticipantState) -> Self {
+            Self {
+                id,
+                state,
+                on_message_sent: vec![],
+            }
+        }
+
+        pub fn add_on_message_sent<F>(&mut self, callback: F)
+        where
+            F: FnOnce(ElectionMessage, ParticipantId) + Send + 'static,
+        {
+            self.on_message_sent.push(Box::new(callback));
+        }
+
+        pub async fn receive_message(mut self_: ParticipantSelf, message: ElectionMessage) {
+            let foo = self_.lock_owned().await;
+        }
+    }
+}
+
+pub const TASK_CHANNEL_SIZE: usize = 8;
+pub const ELECTION_VICTORY_TIMEOUT: Duration = Duration::from_secs(1);
+pub const VICTORY_WAITING_TIMEOUT: Duration = Duration::from_secs(1);
+pub const HEARTBEAT_RECEIVE_TIMEOUT: Duration = Duration::from_secs(2);
+pub const HEARTBEAT_SEND_TIMEOUT: Duration = Duration::from_secs(1);
+
+pub enum InMessage {
+    ElectionMessage(ElectionMessage),
+    Heartbeat,
+}
+
+pub enum OutMessage {
+    ElectionMessage(ElectionMessage),
+    Heartbeat,
+}
+
+#[derive(Error, Debug)]
+pub enum ParticipantError {
+    #[error("could not send out-message: `{0}`")]
+    SendError(#[from] mpsc::error::SendError<OutMessage>),
+}
+
+pub async fn participant_task<C: Communicator>(
+    mut in_channel: mpsc::Receiver<InMessage>,
+    out_channel: mpsc::Sender<OutMessage>,
+    communicator: C,
+) -> Result<(), ParticipantError> {
+    // task state
+    let participant_id: ParticipantId = ParticipantId(1234u128); // TODO: replace with dependency injection
+    let mut participant_state: ParticipantState = ParticipantState::Transient;
+
+    // TODO: slot this logic in here somewhere...
+    //       4. If P receives an Election message from another process with a lower ID, it sends an Answer message
+    //       back and, if it has not already started an election, it starts the election process at the beginning,
+    //       by sending an Election message to higher-numbered processes.
+
+    loop {
+        match participant_state {
+            ParticipantState::Transient => {
+                // When a process P recovers from failure, or the failure detector indicates
+                // that the current coordinator has failed, P performs the following actions:
+                //
+                // 1A) If P has the highest process ID, it sends a Victory message to all other
+                //     processes and becomes the new Coordinator.
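+                //     (Note: the `<=` comparison below also claims victory when our own
+                //     ID ties the maximum, i.e. it assumes participant IDs are unique.)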
+ let max_id = communicator + .all_participants() + .iter() + .max() + .unwrap_or(&ParticipantId(0u128)); + if max_id <= &participant_id { + participant_state = ParticipantState::Coordinator; + communicator.broadcast_message( + ElectionMessage::Victory { + coordinator: participant_id, + }, + communicator.all_participants(), + ); + continue; + } + + // 1B) Otherwise, P broadcasts an Election message to all other processes with + // higher process IDs than itself + participant_state = ParticipantState::ElectionCandidate; + communicator.broadcast_message( + ElectionMessage::Election { + candidate: participant_id, + }, + &communicator + .all_participants() + .iter() + .filter(|&p| p > &participant_id) + .copied() + .collect::>(), + ); + } + ParticipantState::ElectionCandidate => { + tokio::select! { + // 2. If P receives no Answer after sending an Election message, then it broadcasts + // a Victory message to all other processes and becomes the Coordinator. + _ = tokio::time::sleep(ELECTION_VICTORY_TIMEOUT) => { + participant_state = ParticipantState::Coordinator; + communicator.broadcast_message( + ElectionMessage::Victory { + coordinator: participant_id, + }, + communicator.all_participants(), + ); + } + + // 3A. If P receives an Answer from a process with a higher ID, it sends no further + // messages for this election and waits for a Victory message. (If there is no Victory + // message after a period of time, it restarts the process at the beginning.) + Some(InMessage::ElectionMessage(ElectionMessage::Alive)) = in_channel.recv() => { + participant_state = ParticipantState::Waiting; + } // TODO: handle all other branches, e.g. channel closure, different messages & so on + } + } + ParticipantState::Waiting => { + tokio::select! { + // 3B. If there is no Victory message after a period of time, it restarts the process + // at the beginning. + _ = tokio::time::sleep(VICTORY_WAITING_TIMEOUT) => { + participant_state = ParticipantState::Transient; + } + + // 5. If P receives a Victory message, it treats the sender as the coordinator. + Some(InMessage::ElectionMessage(ElectionMessage::Victory { coordinator })) = in_channel.recv() => { + participant_state = ParticipantState::Follower { id: coordinator }; + } // TODO: handle all other branches, e.g. channel closure, different messages & so on + } + } + ParticipantState::Follower { id: coordinator_id } => { + tokio::select! { + // If we do not receive a heartbeat from the coordinator, trigger new election + _ = tokio::time::sleep(VICTORY_WAITING_TIMEOUT) => { + participant_state = ParticipantState::Transient; + } + + // If we do receive a heartbeat - keep going + Some(InMessage::Heartbeat) = in_channel.recv() => { + } // TODO: handle all other branches, e.g. channel closure, different messages & so on + } + } + ParticipantState::Coordinator => { + // If we are coordinator - send heart beats + { + out_channel.send(OutMessage::Heartbeat).await?; + tokio::time::sleep(HEARTBEAT_SEND_TIMEOUT).await; + } + } + } + } +} diff --git a/rust/master_election/tests/dummy.rs b/rust/master_election/tests/dummy.rs new file mode 100644 index 00000000..d82c6eb1 --- /dev/null +++ b/rust/master_election/tests/dummy.rs @@ -0,0 +1,8 @@ +// maybe this will hold test in the future...?? 
+ +#[cfg(test)] +mod tests { + #[test] + fn does_nothing() { + } +} \ No newline at end of file diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 00000000..271800cb --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" \ No newline at end of file diff --git a/rust/util/Cargo.toml b/rust/util/Cargo.toml new file mode 100644 index 00000000..b818252e --- /dev/null +++ b/rust/util/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "util" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +doctest = false +name = "util" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +# macro dependencies +extend = { workspace = true } + +# utility dependencies +thiserror = { workspace = true } +once_cell = { workspace = true } +internment = { workspace = true } +derive_more = { workspace = true } +bon = { workspace = true } +recursion = { workspace = true } +fn_pipe = { workspace = true } diff --git a/rust/util/fn_pipe/Cargo.toml b/rust/util/fn_pipe/Cargo.toml new file mode 100644 index 00000000..fed18ea1 --- /dev/null +++ b/rust/util/fn_pipe/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "fn_pipe" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +doctest = false +name = "fn_pipe" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +fn_pipe_proc = { workspace = true } \ No newline at end of file diff --git a/rust/util/fn_pipe/proc/Cargo.toml b/rust/util/fn_pipe/proc/Cargo.toml new file mode 100644 index 00000000..087d9500 --- /dev/null +++ b/rust/util/fn_pipe/proc/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "fn_pipe_proc" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +name = "fn_pipe_proc" +path = "src/lib.rs" +proc-macro = true + +[lints] +workspace = true + +[dependencies] +extend = { workspace = true } +syn = { workspace = true } +quote = { workspace = true } +proc-macro2 = { workspace = true } +darling = { workspace = true } diff --git a/rust/util/fn_pipe/proc/src/lib.rs b/rust/util/fn_pipe/proc/src/lib.rs new file mode 100644 index 00000000..3a471522 --- /dev/null +++ b/rust/util/fn_pipe/proc/src/lib.rs @@ -0,0 +1,201 @@ +//! Proc-macro for implementing `Fn/Pipe*` variants for tuples of a given size; +//! 
it is only here for this one purpose and no other, should not be used elsewhere + +#![allow(clippy::arbitrary_source_item_ordering)] + +extern crate proc_macro; + +use extend::ext; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, LitInt}; + +type TokS2 = proc_macro2::TokenStream; + +#[allow( + clippy::unwrap_used, + clippy::indexing_slicing, + clippy::arithmetic_side_effects, + clippy::missing_panics_doc, + clippy::too_many_lines +)] +#[proc_macro] +pub fn impl_fn_pipe_for_tuple(item: TokenStream) -> TokenStream { + // DEFINE CONSTANT TOKEN STREAMS UPFRONT + // token streams for Fn/Pipe* variants + let fn_pipe_names = ( + ( + "Fn".parse_unchecked(), + "FnPipe".parse_unchecked(), + "run".parse_unchecked(), + "call".parse_unchecked(), + ), + ( + "FnMut".parse_unchecked(), + "FnMutPipe".parse_unchecked(), + "run_mut".parse_unchecked(), + "call_mut".parse_unchecked(), + ), + ( + "FnOnce".parse_unchecked(), + "FnOncePipe".parse_unchecked(), + "run_once".parse_unchecked(), + "call_once".parse_unchecked(), + ), + ); + + // get the number of tuple parameters to implement this for + let max_tuple_size = match parse_macro_input!(item as LitInt).base10_parse::() { + Ok(num) => num, + Err(e) => return e.to_compile_error().into(), + }; + assert!( + max_tuple_size > 0, + "passed parameter must be greater than zero" + ); + + // generate generic function type-names, to be used later everywhere + let mut fn_type_names = Vec::with_capacity(max_tuple_size); + for i in 0..max_tuple_size { + fn_type_names.push(format!("_{i}").parse_unchecked()); + } + + // create a middle type constraint (i.e. not the first one) + let middle_type_constraint = |prev_fn: TokS2, this_fn: TokS2, fn_name: TokS2| { + quote! { + #this_fn: #fn_name<(#prev_fn::Output,)> + } + }; + + // create call implementation + let impl_call = |n: usize, call: TokS2, base: TokS2| { + let tuple_access = format!("self.{n}").parse_unchecked(); + quote! { + #tuple_access.#call((#base,)) + } + }; + + // generic impl block parametrised on the variant and number of params + let impl_per_type_and_n = |n: usize, + (fn_name, fn_pipe_name, run, call): (TokS2, TokS2, TokS2, TokS2), + extra: Option, + ref_style: Option| { + // flatten the extra tokens + let extra = extra.unwrap_or_default(); + + let fn_type_names_comma_sep = &fn_type_names[0..n].comma_separated(); + + // get name of first type and create the type constraint for the fist type + let first_fn_type = fn_type_names[0].clone(); + let first_type_constraint = quote! { + #first_fn_type: #fn_name + }; + + // create the middle type constraint implementations + let middle_type_constraints = (1..n) + .map(|i| { + // get previous and current tokens + let prev_fn = fn_type_names[i - 1].clone(); + let this_fn = fn_type_names[i].clone(); + + // create middle implementation + middle_type_constraint(prev_fn, this_fn, fn_name.clone()) + }) + .collect::>(); + + // combine the two, and comma-separate them into a single block + let type_constraints = [vec![first_type_constraint], middle_type_constraints] + .concat() + .as_slice() + .comma_separated(); + + // recursive call implementation starting from the base + let mut call_impl = quote! { self.0 .#call(args) }; + for i in 1..n { + call_impl = impl_call(i, call.clone(), call_impl); + } + + quote! 
{ + #[allow(clippy::type_repetition_in_bounds)] + impl #fn_pipe_name for (#fn_type_names_comma_sep,) + where #type_constraints + { + #extra + + #[inline] + extern "rust-call" fn #run(#ref_style self, args: Args) -> Self::Output { + #call_impl + } + } + } + }; + + // generic impl block parametrised on the number of params + let impl_per_n = |n: usize| { + // create the `Fn/FnPipe` implementation + let mut impl_per_n = + impl_per_type_and_n(n, fn_pipe_names.0.clone(), None, Some(quote! { & })); + + // create the `FnMut/FnMutPipe` implementation + impl_per_n.extend(impl_per_type_and_n( + n, + fn_pipe_names.1.clone(), + None, + Some(quote! { &mut }), + )); + + // create the `FnOnce/FnOncePipe` implementation; + // this implementation additionally needs to specify the associated `type Output` + let last = fn_type_names[n - 1].clone(); + impl_per_n.extend(impl_per_type_and_n( + n, + fn_pipe_names.2.clone(), + Some(quote! { + type Output = #last::Output; + }), + None, + )); + + impl_per_n + }; + + // we need to implement for all tuple sizes 1 through-to `n` + let mut impls = TokS2::new(); + for n in 1..=max_tuple_size { + impls.extend(impl_per_n(n)); + } + + // return all the impls + impls.into() +} + +#[ext] +impl [TokS2] { + #[allow(clippy::unwrap_used, clippy::single_call_fn)] + fn comma_separated(&self) -> TokS2 { + let comma_tok = ",".parse_unchecked(); + + // get the first token, and turn it into an accumulator + let mut toks = self.iter(); + let mut tok: TokS2 = toks.next().unwrap().clone(); + + // if there are more tokens to come, keep extending with comma + for next in toks { + tok.extend(comma_tok.clone()); + tok.extend(next.clone()); + } + + // return final comma-separated result + tok + } +} + +#[ext] +impl str { + fn parse_unchecked(&self) -> TokS2 { + match self.parse::() { + Ok(s) => s, + Err(e) => unimplemented!("{e}"), + } + } +} diff --git a/rust/util/fn_pipe/src/lib.rs b/rust/util/fn_pipe/src/lib.rs new file mode 100644 index 00000000..44dbc01d --- /dev/null +++ b/rust/util/fn_pipe/src/lib.rs @@ -0,0 +1,35 @@ +//! TODO: crate documentation +//! +//! this is here as a placeholder documentation + +// enable Rust-unstable features for convenience +#![feature(tuple_trait)] +#![feature(unboxed_closures)] +#![feature(fn_traits)] +#![feature(unsized_fn_params)] // this is fine because I am PURELY wrapping around existing `Fn*` traits +// global lints +#![allow(internal_features)] +#![allow(clippy::arbitrary_source_item_ordering)] + +use fn_pipe_proc::impl_fn_pipe_for_tuple; +use std::marker::Tuple; + +/// A trait representing a pipe of functions, where the output of one will +/// be fed as the input of another, until the entire pipe ran +pub trait FnPipe: FnMutPipe { + extern "rust-call" fn run(&self, args: Args) -> Self::Output; +} + +pub trait FnMutPipe: FnOncePipe { + extern "rust-call" fn run_mut(&mut self, args: Args) -> Self::Output; +} + +pub trait FnOncePipe { + type Output; + + extern "rust-call" fn run_once(self, args: Args) -> Self::Output; +} + +// implement `Fn/Pipe*` variants for tuples of upto length 26, +// can be increased in the future +impl_fn_pipe_for_tuple!(26usize); diff --git a/rust/util/src/lib.rs b/rust/util/src/lib.rs new file mode 100644 index 00000000..5c34786c --- /dev/null +++ b/rust/util/src/lib.rs @@ -0,0 +1,53 @@ +//! TODO: crate documentation +//! +//! this is here as a placeholder documentation +//! +//! 
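+//! For now the crate exposes the [`nonempty`] module (a non-empty boxed slice,
+//! `NonemptyArray`) and small `map` helpers for `Box<[T]>`/`Vec<T>` in [`ext`].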
+ +// enable Rust-unstable features for convenience +#![feature(trait_alias)] +#![feature(stmt_expr_attributes)] +#![feature(type_alias_impl_trait)] +#![feature(specialization)] +#![feature(unboxed_closures)] +#![feature(const_trait_impl)] +#![feature(fn_traits)] + +pub mod nonempty; + +pub(crate) mod private { + // sealed traits support + pub trait Sealed {} + impl Sealed for T {} +} + +/// Namespace for all the type/trait aliases used by this crate. +pub(crate) mod alias { +} + +/// Namespace for crate-wide extension traits/methods +pub mod ext { + use extend::ext; + + #[ext(pub, name = BoxedSliceExt)] + impl Box<[T]> { + #[inline] + fn map(self, f: F) -> Box<[B]> + where + F: FnMut(T) -> B, + { + self.into_iter().map(f).collect() + } + } + + #[ext(pub, name = VecExt)] + impl Vec { + #[inline] + fn map(self, f: F) -> Vec + where + F: FnMut(T) -> B, + { + self.into_iter().map(f).collect() + } + } +} diff --git a/rust/util/src/nonempty.rs b/rust/util/src/nonempty.rs new file mode 100644 index 00000000..acfcf971 --- /dev/null +++ b/rust/util/src/nonempty.rs @@ -0,0 +1,145 @@ +use fn_pipe::FnMutPipe; +use std::slice::SliceIndex; +use std::{ops, slice}; +use thiserror::Error; + +#[derive(Error, Debug)] +#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")] +pub struct EmptySliceError; + +/// A pointer to a non-empty fixed-size slice allocated on the heap. +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[repr(transparent)] +pub struct NonemptyArray(Box<[T]>); + +#[allow(clippy::arbitrary_source_item_ordering)] +impl NonemptyArray { + #[inline] + pub fn singleton(value: T) -> Self { + Self(Box::new([value])) + } + + #[allow(clippy::missing_errors_doc)] + #[inline] + pub fn try_from_boxed_slice>>( + boxed_slice: S, + ) -> Result { + let boxed_slice = boxed_slice.into(); + if boxed_slice.is_empty() { + Err(EmptySliceError) + } else { + Ok(Self(boxed_slice)) + } + } + + #[must_use] + #[inline] + pub fn into_boxed_slice(self) -> Box<[T]> { + self.0 + } + + #[must_use] + #[inline] + pub fn to_vec(&self) -> Vec + where + T: Clone, + { + self.0.to_vec() + } + + #[must_use] + #[inline] + pub const fn as_slice(&self) -> &[T] { + &self.0 + } + + #[allow(clippy::indexing_slicing)] + #[must_use] + #[inline] + pub fn first(&self) -> &T { + &self.0[0] + } + + #[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)] + #[must_use] + #[inline] + pub fn last(&self) -> &T { + &self.0[self.0.len() - 1] + } + + #[must_use] + #[inline] + pub fn get(&self, index: I) -> Option<&I::Output> + where + I: SliceIndex<[T]>, + { + self.0.get(index) + } + + #[allow(clippy::len_without_is_empty)] + #[must_use] + #[inline] + pub const fn len(&self) -> usize { + self.0.len() + } + + #[allow(clippy::iter_without_into_iter)] + #[inline] + pub fn iter(&self) -> slice::Iter<'_, T> { + self.0.iter() + } + + #[allow(clippy::iter_without_into_iter)] + #[inline] + pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> { + self.0.iter_mut() + } + + #[inline] + #[must_use] + pub fn map U>(self, f: F) -> NonemptyArray { + NonemptyArray(self.0.into_iter().map(f).collect()) + } + + #[inline] + #[must_use] + pub fn pipe U>(self, mut p: P) -> NonemptyArray { + self.map(|x| p.run_mut((x,))) + } +} + +impl From> for Box<[T]> { + #[inline] + fn from(value: NonemptyArray) -> Self { + value.into_boxed_slice() + } +} + +impl ops::Index for NonemptyArray { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + +impl IntoIterator for 
NonemptyArray { + type Item = T; + type IntoIter = std::vec::IntoIter; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.into_boxed_slice().into_vec().into_iter() + } +} + +impl<'a, T> IntoIterator for &'a NonemptyArray { + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py index bdf34948..2009c8c0 100644 --- a/shared/db/sqlite/connector.py +++ b/shared/db/sqlite/connector.py @@ -12,9 +12,8 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlmodel import SQLModel -from shared.types.events.common import NodeId +from shared.types.events import Event, EventParser, NodeId from shared.types.events.components import EventFromEventLog -from shared.types.events.registry import Event, EventParser from .types import StoredEvent diff --git a/shared/db/sqlite/types.py b/shared/db/sqlite/types.py index 880de7b3..262fe4a7 100644 --- a/shared/db/sqlite/types.py +++ b/shared/db/sqlite/types.py @@ -5,8 +5,8 @@ from sqlalchemy import DateTime, Index from sqlmodel import JSON, Column, Field, SQLModel from shared.types.common import NodeId +from shared.types.events import Event from shared.types.events.components import EventFromEventLog -from shared.types.events.registry import Event class StoredEvent(SQLModel, table=True): diff --git a/shared/event_loops/main.py b/shared/event_loops/main.py index d481b3f4..582745e6 100644 --- a/shared/event_loops/main.py +++ b/shared/event_loops/main.py @@ -7,8 +7,8 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence from fastapi.responses import Response, StreamingResponse from shared.event_loops.commands import ExternalCommand +from shared.types.events import Event from shared.types.events.components import Apply, EventFromEventLog -from shared.types.events.registry import Event from shared.types.state import State diff --git a/shared/pyproject.toml b/shared/pyproject.toml index 95a78f5c..c4c5adeb 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "rustworkx>=0.16.0", "sqlmodel>=0.0.22", "sqlalchemy[asyncio]>=2.0.0", + "greenlet>=3.2.3" ] [build-system] @@ -37,4 +38,6 @@ exclude = ["protobufs/schemas", "*.md", "pyproject.toml"] [dependency-groups] dev = [ "types-protobuf>=6.30.2.20250516", + "pytest>=8.4.0", + "pytest-asyncio>=1.0.0", ] diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py index 32e9ea8c..50fef7ad 100644 --- a/shared/tests/test_sqlite_connector.py +++ b/shared/tests/test_sqlite_connector.py @@ -11,11 +11,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig from shared.types.common import NodeId -from shared.types.events.chunks import ChunkType, TokenChunk -from shared.types.events.events import ( +from shared.types.events import ( ChunkGenerated, - EventType, + _EventType, ) +from shared.types.events.chunks import ChunkType, TokenChunk from shared.types.request import RequestId # Type ignore comment for all protected member access in this test file @@ -472,7 +472,7 @@ class TestAsyncSQLiteEventStorage: # Verify the event was deserialized correctly retrieved_event = retrieved_event_wrapper.event assert isinstance(retrieved_event, ChunkGenerated) - assert retrieved_event.event_type == EventType.ChunkGenerated + assert retrieved_event.event_type == _EventType.ChunkGenerated 
assert retrieved_event.request_id == request_id # Verify the nested chunk was deserialized correctly diff --git a/shared/types/events/__init__.py b/shared/types/events/__init__.py new file mode 100644 index 00000000..db6adbd5 --- /dev/null +++ b/shared/types/events/__init__.py @@ -0,0 +1,99 @@ +# ruff: noqa: F403 +# ruff: noqa: F405 + +import types +import typing +from typing import Annotated, Union + +# Note: we are implementing internal details here, so importing private stuff is fine!!! +from pydantic import Field, TypeAdapter + +from ...constants import get_error_reporting_message +from ._common import * +from ._common import _BaseEvent, _EventType # pyright: ignore[reportPrivateUsage] +from ._events import * + +_Event = Union[ + TaskCreated, + TaskStateUpdated, + TaskDeleted, + InstanceCreated, + InstanceActivated, + InstanceDeactivated, + InstanceDeleted, + InstanceReplacedAtomically, + RunnerStatusUpdated, + NodePerformanceMeasured, + WorkerConnected, + WorkerStatusUpdated, + WorkerDisconnected, + ChunkGenerated, + TopologyEdgeCreated, + TopologyEdgeReplacedAtomically, + TopologyEdgeDeleted, + MLXInferenceSagaPrepare, + MLXInferenceSagaStartPrepare, +] +""" +Un-annotated union of all events. Only used internally to build the annotated :class:`Event` union below. +For all other use cases, use the annotated union of events :class:`Event` :) +""" + +Event = Annotated[_Event, Field(discriminator="event_type")] +"""Type of events, a discriminated union.""" + +EventParser: TypeAdapter[Event] = TypeAdapter(Event) +"""Type adapter to parse :class:`Event`s.""" + + +def _check_event_type_consistency() -> None: + # Grab enum values from members + member_enum_values = [m for m in _EventType] + + # grab enum values from the union => scrape the type annotation + union_enum_values: list[_EventType] = [] + union_classes = list(typing.get_args(_Event)) + for cls in union_classes: # pyright: ignore[reportAny] + assert issubclass(cls, object), ( + f"{get_error_reporting_message()}", + f"The class {cls} is NOT a subclass of {object}." + ) + + # ensure the first base parameter is ALWAYS _BaseEvent + base_cls = list(types.get_original_bases(cls)) + assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \ + and issubclass(base_cls[0], _BaseEvent), ( + f"{get_error_reporting_message()}", + f"The class {cls} does NOT inherit from {_BaseEvent} (got {typing.get_origin(base_cls[0])})." + ) + + # grab type hints and extract the right values from them + cls_hints = typing.get_type_hints(cls) + assert "event_type" in cls_hints and \ + typing.get_origin(cls_hints["event_type"]) is typing.Literal, ( # pyright: ignore[reportAny] + f"{get_error_reporting_message()}", + f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field." + ) + + # make sure the value is an instance of `_EventType` + enum_value = list(typing.get_args(cls_hints["event_type"])) + assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), ( + f"{get_error_reporting_message()}", + f"The `event_type` of {cls} has a non-{_EventType} literal type." + ) + union_enum_values.append(enum_value[0]) + + # ensure there is a bijection between the two + for m in member_enum_values: + assert m in union_enum_values, ( + f"{get_error_reporting_message()}", + f"There is no event-type registered for {m} in {_Event}." + ) + union_enum_values.remove(m) + assert len(union_enum_values) == 0, ( + f"{get_error_reporting_message()}", + f"The following event types are bound to more than one event class in {_Event}: {union_enum_values}."
+ ) + + +_check_event_type_consistency() diff --git a/shared/types/events/common.py b/shared/types/events/_common.py similarity index 67% rename from shared/types/events/common.py rename to shared/types/events/_common.py index f19f17a4..72788da1 100644 --- a/shared/types/events/common.py +++ b/shared/types/events/_common.py @@ -1,9 +1,5 @@ from enum import Enum -from typing import ( - TYPE_CHECKING, - Generic, - TypeVar, -) +from typing import TYPE_CHECKING if TYPE_CHECKING: pass @@ -14,40 +10,44 @@ from shared.types.common import NewUUID, NodeId class EventId(NewUUID): - pass + """ + Newtype around `NewUUID`. + """ -class TimerId(NewUUID): - pass +# Event base-class boilerplate (you should basically never touch these) +# Only very specialised registry or serialisation/deserialisation logic might need to know about these +class _EventType(str, Enum): + """ + Here are all the unique kinds of events that can be sent over the network. + """ -# Here are all the unique kinds of events that can be sent over the network. -class EventType(str, Enum): # Task Saga Events MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare" MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare" - + # Task Events TaskCreated = "TaskCreated" TaskStateUpdated = "TaskStateUpdated" TaskDeleted = "TaskDeleted" - + # Streaming Events ChunkGenerated = "ChunkGenerated" - + # Instance Events InstanceCreated = "InstanceCreated" InstanceDeleted = "InstanceDeleted" InstanceActivated = "InstanceActivated" InstanceDeactivated = "InstanceDeactivated" InstanceReplacedAtomically = "InstanceReplacedAtomically" - + # Runner Status Events RunnerStatusUpdated = "RunnerStatusUpdated" - + # Node Performance Events NodePerformanceMeasured = "NodePerformanceMeasured" - + # Topology Events TopologyEdgeCreated = "TopologyEdgeCreated" TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically" @@ -55,25 +55,26 @@ class EventType(str, Enum): WorkerConnected = "WorkerConnected" WorkerStatusUpdated = "WorkerStatusUpdated" WorkerDisconnected = "WorkerDisconnected" - - # Timer Events - TimerCreated = "TimerCreated" - TimerFired = "TimerFired" -EventTypeT = TypeVar("EventTypeT", bound=EventType) + # # Timer Events + # TimerCreated = "TimerCreated" + # TimerFired = "TimerFired" -class BaseEvent(BaseModel, Generic[EventTypeT]): - event_type: EventTypeT +class _BaseEvent[T: _EventType](BaseModel): # pyright: ignore[reportUnusedClass] + """ + This is the event base-class, to please the Pydantic gods. + PLEASE don't use this for anything unless you know why you are doing so; + instead, just use the events union :) + """ + + event_type: T event_id: EventId = EventId() def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: """Check if the event was sent by the correct node. - + This is a placeholder implementation that always returns True. Subclasses can override this method to implement specific validation logic.
""" return True - - - diff --git a/shared/types/events/_events.py b/shared/types/events/_events.py new file mode 100644 index 00000000..0c3a80f7 --- /dev/null +++ b/shared/types/events/_events.py @@ -0,0 +1,132 @@ +from typing import Literal + +from shared.topology import Connection, ConnectionProfile, Node, NodePerformanceProfile +from shared.types.common import NodeId +from shared.types.events.chunks import GenerationChunk +from shared.types.request import RequestId +from shared.types.tasks import Task, TaskId, TaskStatus +from shared.types.worker.common import InstanceId, NodeStatus +from shared.types.worker.instances import InstanceParams, TypeOfInstance +from shared.types.worker.runners import RunnerId, RunnerStatus + +from ._common import _BaseEvent, _EventType # pyright: ignore[reportPrivateUsage] + + +class TaskCreated(_BaseEvent[_EventType.TaskCreated]): + event_type: Literal[_EventType.TaskCreated] = _EventType.TaskCreated + task_id: TaskId + task: Task + + +class TaskDeleted(_BaseEvent[_EventType.TaskDeleted]): + event_type: Literal[_EventType.TaskDeleted] = _EventType.TaskDeleted + task_id: TaskId + + +class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]): + event_type: Literal[_EventType.TaskStateUpdated] = _EventType.TaskStateUpdated + task_id: TaskId + task_status: TaskStatus + + +class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]): + event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated + instance_id: InstanceId + instance_params: InstanceParams + instance_type: TypeOfInstance + + +class InstanceActivated(_BaseEvent[_EventType.InstanceActivated]): + event_type: Literal[_EventType.InstanceActivated] = _EventType.InstanceActivated + instance_id: InstanceId + + +class InstanceDeactivated(_BaseEvent[_EventType.InstanceDeactivated]): + event_type: Literal[_EventType.InstanceDeactivated] = _EventType.InstanceDeactivated + instance_id: InstanceId + + +class InstanceDeleted(_BaseEvent[_EventType.InstanceDeleted]): + event_type: Literal[_EventType.InstanceDeleted] = _EventType.InstanceDeleted + instance_id: InstanceId + + transition: tuple[InstanceId, InstanceId] + + +class InstanceReplacedAtomically(_BaseEvent[_EventType.InstanceReplacedAtomically]): + event_type: Literal[_EventType.InstanceReplacedAtomically] = _EventType.InstanceReplacedAtomically + instance_to_replace: InstanceId + new_instance_id: InstanceId + + +class RunnerStatusUpdated(_BaseEvent[_EventType.RunnerStatusUpdated]): + event_type: Literal[_EventType.RunnerStatusUpdated] = _EventType.RunnerStatusUpdated + runner_id: RunnerId + runner_status: RunnerStatus + + +class MLXInferenceSagaPrepare(_BaseEvent[_EventType.MLXInferenceSagaPrepare]): + event_type: Literal[_EventType.MLXInferenceSagaPrepare] = _EventType.MLXInferenceSagaPrepare + task_id: TaskId + instance_id: InstanceId + + +class MLXInferenceSagaStartPrepare(_BaseEvent[_EventType.MLXInferenceSagaStartPrepare]): + event_type: Literal[_EventType.MLXInferenceSagaStartPrepare] = _EventType.MLXInferenceSagaStartPrepare + task_id: TaskId + instance_id: InstanceId + + +class NodePerformanceMeasured(_BaseEvent[_EventType.NodePerformanceMeasured]): + event_type: Literal[_EventType.NodePerformanceMeasured] = _EventType.NodePerformanceMeasured + node_id: NodeId + node_profile: NodePerformanceProfile + + +class WorkerConnected(_BaseEvent[_EventType.WorkerConnected]): + event_type: Literal[_EventType.WorkerConnected] = _EventType.WorkerConnected + edge: Connection + + +class 
WorkerStatusUpdated(_BaseEvent[_EventType.WorkerStatusUpdated]): + event_type: Literal[_EventType.WorkerStatusUpdated] = _EventType.WorkerStatusUpdated + node_id: NodeId + node_state: NodeStatus + + +class WorkerDisconnected(_BaseEvent[_EventType.WorkerDisconnected]): + event_type: Literal[_EventType.WorkerDisconnected] = _EventType.WorkerDisconnected + vertex_id: NodeId + + +class ChunkGenerated(_BaseEvent[_EventType.ChunkGenerated]): + event_type: Literal[_EventType.ChunkGenerated] = _EventType.ChunkGenerated + request_id: RequestId + chunk: GenerationChunk + + +class TopologyEdgeCreated(_BaseEvent[_EventType.TopologyEdgeCreated]): + event_type: Literal[_EventType.TopologyEdgeCreated] = _EventType.TopologyEdgeCreated + vertex: Node + + +class TopologyEdgeReplacedAtomically(_BaseEvent[_EventType.TopologyEdgeReplacedAtomically]): + event_type: Literal[_EventType.TopologyEdgeReplacedAtomically] = _EventType.TopologyEdgeReplacedAtomically + edge: Connection + edge_profile: ConnectionProfile + + +class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]): + event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted + edge: Connection + + +# class TimerCreated(_BaseEvent[_EventType.TimerCreated]): +# event_type: Literal[_EventType.TimerCreated] = _EventType.TimerCreated +# timer_id: TimerId +# delay_seconds: float +# +# +# class TimerFired(_BaseEvent[_EventType.TimerFired]): +# event_type: Literal[_EventType.TimerFired] = _EventType.TimerFired +# timer_id: TimerId \ No newline at end of file diff --git a/shared/types/events/categories.py b/shared/types/events/categories.py index 0059348c..3954af21 100644 --- a/shared/types/events/categories.py +++ b/shared/types/events/categories.py @@ -1,10 +1,9 @@ - -from shared.types.events.events import ( +from . import ( MLXInferenceSagaPrepare, MLXInferenceSagaStartPrepare, ) TaskSagaEvent = ( - MLXInferenceSagaPrepare - | MLXInferenceSagaStartPrepare -) \ No newline at end of file + MLXInferenceSagaPrepare + | MLXInferenceSagaStartPrepare +) diff --git a/shared/types/events/commands.py b/shared/types/events/commands.py index 9d7cd1ff..6651d823 100644 --- a/shared/types/events/commands.py +++ b/shared/types/events/commands.py @@ -11,9 +11,10 @@ if TYPE_CHECKING: from pydantic import BaseModel from shared.types.common import NewUUID -from shared.types.events.registry import Event from shared.types.state import State +from . 
import Event + class CommandId(NewUUID): pass diff --git a/shared/types/events/components.py b/shared/types/events/components.py index 2f6d5087..0a676ae8 100644 --- a/shared/types/events/components.py +++ b/shared/types/events/components.py @@ -13,7 +13,7 @@ from typing import Callable from pydantic import BaseModel, Field, model_validator from shared.types.common import NodeId -from shared.types.events.registry import Event +from shared.types.events import Event from shared.types.state import State diff --git a/shared/types/events/events.py b/shared/types/events/events.py deleted file mode 100644 index 90c98a27..00000000 --- a/shared/types/events/events.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Tuple - -from shared.topology import Connection, ConnectionProfile, Node, NodePerformanceProfile -from shared.types.common import NodeId -from shared.types.events.chunks import GenerationChunk -from shared.types.events.common import ( - BaseEvent, - EventType, - TimerId, -) -from shared.types.request import RequestId -from shared.types.tasks import Task, TaskId, TaskStatus -from shared.types.worker.common import InstanceId, NodeStatus -from shared.types.worker.instances import InstanceParams, TypeOfInstance -from shared.types.worker.runners import RunnerId, RunnerStatus - - -class TaskCreated(BaseEvent[EventType.TaskCreated]): - event_type: Literal[EventType.TaskCreated] = EventType.TaskCreated - task_id: TaskId - task: Task - - -class TaskDeleted(BaseEvent[EventType.TaskDeleted]): - event_type: Literal[EventType.TaskDeleted] = EventType.TaskDeleted - task_id: TaskId - - -class TaskStateUpdated(BaseEvent[EventType.TaskStateUpdated]): - event_type: Literal[EventType.TaskStateUpdated] = EventType.TaskStateUpdated - task_id: TaskId - task_status: TaskStatus - - -class InstanceCreated(BaseEvent[EventType.InstanceCreated]): - event_type: Literal[EventType.InstanceCreated] = EventType.InstanceCreated - instance_id: InstanceId - instance_params: InstanceParams - instance_type: TypeOfInstance - - -class InstanceActivated(BaseEvent[EventType.InstanceActivated]): - event_type: Literal[EventType.InstanceActivated] = EventType.InstanceActivated - instance_id: InstanceId - - -class InstanceDeactivated(BaseEvent[EventType.InstanceDeactivated]): - event_type: Literal[EventType.InstanceDeactivated] = EventType.InstanceDeactivated - instance_id: InstanceId - - -class InstanceDeleted(BaseEvent[EventType.InstanceDeleted]): - event_type: Literal[EventType.InstanceDeleted] = EventType.InstanceDeleted - instance_id: InstanceId - - transition: Tuple[InstanceId, InstanceId] - - -class InstanceReplacedAtomically(BaseEvent[EventType.InstanceReplacedAtomically]): - event_type: Literal[EventType.InstanceReplacedAtomically] = EventType.InstanceReplacedAtomically - instance_to_replace: InstanceId - new_instance_id: InstanceId - - -class RunnerStatusUpdated(BaseEvent[EventType.RunnerStatusUpdated]): - event_type: Literal[EventType.RunnerStatusUpdated] = EventType.RunnerStatusUpdated - runner_id: RunnerId - runner_status: RunnerStatus - - -class MLXInferenceSagaPrepare(BaseEvent[EventType.MLXInferenceSagaPrepare]): - event_type: Literal[EventType.MLXInferenceSagaPrepare] = EventType.MLXInferenceSagaPrepare - task_id: TaskId - instance_id: InstanceId - - -class MLXInferenceSagaStartPrepare(BaseEvent[EventType.MLXInferenceSagaStartPrepare]): - event_type: Literal[EventType.MLXInferenceSagaStartPrepare] = EventType.MLXInferenceSagaStartPrepare - task_id: TaskId - 
instance_id: InstanceId - - -class NodePerformanceMeasured(BaseEvent[EventType.NodePerformanceMeasured]): - event_type: Literal[EventType.NodePerformanceMeasured] = EventType.NodePerformanceMeasured - node_id: NodeId - node_profile: NodePerformanceProfile - - -class WorkerConnected(BaseEvent[EventType.WorkerConnected]): - event_type: Literal[EventType.WorkerConnected] = EventType.WorkerConnected - edge: Connection - - -class WorkerStatusUpdated(BaseEvent[EventType.WorkerStatusUpdated]): - event_type: Literal[EventType.WorkerStatusUpdated] = EventType.WorkerStatusUpdated - node_id: NodeId - node_state: NodeStatus - - -class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]): - event_type: Literal[EventType.WorkerDisconnected] = EventType.WorkerDisconnected - vertex_id: NodeId - - -class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]): - event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated - request_id: RequestId - chunk: GenerationChunk - - -class TopologyEdgeCreated(BaseEvent[EventType.TopologyEdgeCreated]): - event_type: Literal[EventType.TopologyEdgeCreated] = EventType.TopologyEdgeCreated - vertex: Node - - -class TopologyEdgeReplacedAtomically(BaseEvent[EventType.TopologyEdgeReplacedAtomically]): - event_type: Literal[EventType.TopologyEdgeReplacedAtomically] = EventType.TopologyEdgeReplacedAtomically - edge: Connection - edge_profile: ConnectionProfile - - -class TopologyEdgeDeleted(BaseEvent[EventType.TopologyEdgeDeleted]): - event_type: Literal[EventType.TopologyEdgeDeleted] = EventType.TopologyEdgeDeleted - edge: Connection - - -class TimerCreated(BaseEvent[EventType.TimerCreated]): - event_type: Literal[EventType.TimerCreated] = EventType.TimerCreated - timer_id: TimerId - delay_seconds: float - - -class TimerFired(BaseEvent[EventType.TimerFired]): - event_type: Literal[EventType.TimerFired] = EventType.TimerFired - timer_id: TimerId \ No newline at end of file diff --git a/shared/types/events/registry.py b/shared/types/events/registry.py deleted file mode 100644 index 959ada0f..00000000 --- a/shared/types/events/registry.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Annotated, Any, Mapping, Type, TypeAlias - -from pydantic import Field, TypeAdapter - -from shared.types.events.common import ( - EventType, -) -from shared.types.events.events import ( - ChunkGenerated, - InstanceActivated, - InstanceCreated, - InstanceDeactivated, - InstanceDeleted, - InstanceReplacedAtomically, - MLXInferenceSagaPrepare, - MLXInferenceSagaStartPrepare, - NodePerformanceMeasured, - RunnerStatusUpdated, - TaskCreated, - TaskDeleted, - TaskStateUpdated, - TimerCreated, - TimerFired, - TopologyEdgeCreated, - TopologyEdgeDeleted, - TopologyEdgeReplacedAtomically, - WorkerConnected, - WorkerDisconnected, - WorkerStatusUpdated, -) -from shared.types.events.sanity_checking import ( - assert_event_union_covers_registry, - check_registry_has_all_event_types, - check_union_of_all_events_is_consistent_with_registry, -) - -""" -class EventTypeNames(StrEnum): - TaskEventType = auto() - InstanceEvent = auto() - NodePerformanceEvent = auto() - ControlPlaneEvent = auto() - StreamingEvent = auto() - DataPlaneEvent = auto() - TimerEvent = auto() - MLXEvent = auto() - -check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames) -""" -EventRegistry: Mapping[EventType, Type[Any]] = { - EventType.TaskCreated: TaskCreated, - EventType.TaskStateUpdated: TaskStateUpdated, - EventType.TaskDeleted: TaskDeleted, - EventType.InstanceCreated: InstanceCreated, 
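The hand-maintained `EventRegistry` mapping being deleted here duplicated information that the discriminated union now carries in the type itself. A standalone illustration of the Literal-discriminator pattern that replaces it (the `Kind`/`Ping`/`Pong` names are hypothetical, not from this patch):

```python
from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class Kind(str, Enum):
    Ping = "Ping"
    Pong = "Pong"


class Ping(BaseModel):
    # Each variant pins `event_type` to a single Literal enum member...
    event_type: Literal[Kind.Ping] = Kind.Ping
    seq: int


class Pong(BaseModel):
    event_type: Literal[Kind.Pong] = Kind.Pong
    seq: int


# ...and the Annotated union dispatches on that field, with no manual registry.
Msg = Annotated[Union[Ping, Pong], Field(discriminator="event_type")]
parser: TypeAdapter[Msg] = TypeAdapter(Msg)

# The discriminator routes the payload straight to the right class:
assert isinstance(parser.validate_python({"event_type": "Ping", "seq": 1}), Ping)
```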
- EventType.InstanceActivated: InstanceActivated, - EventType.InstanceDeactivated: InstanceDeactivated, - EventType.InstanceDeleted: InstanceDeleted, - EventType.InstanceReplacedAtomically: InstanceReplacedAtomically, - EventType.RunnerStatusUpdated: RunnerStatusUpdated, - EventType.NodePerformanceMeasured: NodePerformanceMeasured, - EventType.WorkerConnected: WorkerConnected, - EventType.WorkerStatusUpdated: WorkerStatusUpdated, - EventType.WorkerDisconnected: WorkerDisconnected, - EventType.ChunkGenerated: ChunkGenerated, - EventType.TopologyEdgeCreated: TopologyEdgeCreated, - EventType.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically, - EventType.TopologyEdgeDeleted: TopologyEdgeDeleted, - EventType.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare, - EventType.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare, - EventType.TimerCreated: TimerCreated, - EventType.TimerFired: TimerFired, -} - - -AllEventsUnion = ( - TaskCreated - | TaskStateUpdated - | TaskDeleted - | InstanceCreated - | InstanceActivated - | InstanceDeactivated - | InstanceDeleted - | InstanceReplacedAtomically - | RunnerStatusUpdated - | NodePerformanceMeasured - | WorkerConnected - | WorkerStatusUpdated - | WorkerDisconnected - | ChunkGenerated - | TopologyEdgeCreated - | TopologyEdgeReplacedAtomically - | TopologyEdgeDeleted - | MLXInferenceSagaPrepare - | MLXInferenceSagaStartPrepare - | TimerCreated - | TimerFired -) - -Event: TypeAlias = Annotated[AllEventsUnion, Field(discriminator="event_type")] -EventParser: TypeAdapter[Event] = TypeAdapter(Event) - - - - -assert_event_union_covers_registry(AllEventsUnion) -check_union_of_all_events_is_consistent_with_registry(EventRegistry, AllEventsUnion) -check_registry_has_all_event_types(EventRegistry) \ No newline at end of file diff --git a/shared/types/events/sanity_checking.py b/shared/types/events/sanity_checking.py deleted file mode 100644 index def11557..00000000 --- a/shared/types/events/sanity_checking.py +++ /dev/null @@ -1,75 +0,0 @@ -from enum import StrEnum -from types import UnionType -from typing import Any, Mapping, Set, Type, cast, get_args - -from pydantic.fields import FieldInfo - -from shared.constants import get_error_reporting_message -from shared.types.events.common import EventType - - -def assert_event_union_covers_registry[TEnum: StrEnum]( - literal_union: UnionType, -) -> None: - """ - Ensure that our union of events (AllEventsUnion) has one member per element of Enum - """ - enum_values: Set[str] = {member.value for member in EventType} - - def _flatten(tp: UnionType) -> Set[str]: - values: Set[str] = set() - args = get_args(tp) # Get event classes from the union - for arg in args: # type: ignore[reportAny] - # Cast to type since we know these are class types - event_class = cast(type[Any], arg) - # Each event class is a Pydantic model with model_fields - if hasattr(event_class, 'model_fields'): - model_fields = cast(dict[str, FieldInfo], event_class.model_fields) - if 'event_type' in model_fields: - # Get the default value of the event_type field - event_type_field: FieldInfo = model_fields['event_type'] - if hasattr(event_type_field, 'default'): - default_value = cast(EventType, event_type_field.default) - # The default is an EventType enum member, get its value - values.add(default_value.value) - return values - - literal_values: Set[str] = _flatten(literal_union) - - assert enum_values == literal_values, ( - f"{get_error_reporting_message()}" - f"The values of the enum {EventType} are not covered by the literal union 
{literal_union}.\n" + f"These are the missing values: {enum_values - literal_values}\n" + f"These are the extra values: {literal_values - enum_values}\n" - ) - -def check_union_of_all_events_is_consistent_with_registry( - registry: Mapping[EventType, Type[Any]], union_type: UnionType -) -> None: - type_of_each_registry_entry = set(registry.values()) - type_of_each_entry_in_union = set(get_args(union_type)) - missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union - - assert not missing_from_union, ( - f"{get_error_reporting_message()}" - f"Event classes in registry are missing from all_events union: {missing_from_union}" - ) - - extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry - - assert not extra_in_union, ( - f"{get_error_reporting_message()}" - f"Event classes in all_events union are missing from registry: {extra_in_union}" - ) - -def check_registry_has_all_event_types(event_registry: Mapping[EventType, Type[Any]]) -> None: - event_types: tuple[EventType, ...] = get_args(EventType) - missing_event_types = set(event_types) - set(event_registry.keys()) - - assert not missing_event_types, ( - f"{get_error_reporting_message()}" - f"There's an event missing from the registry: {missing_event_types}" - ) - -# TODO: Check all events have an apply function. -# probably in a different place though. \ No newline at end of file diff --git a/shared/types/worker/ops.py b/shared/types/worker/ops.py index f956a32c..fb4a7521 100644 --- a/shared/types/worker/ops.py +++ b/shared/types/worker/ops.py @@ -3,7 +3,7 @@ from typing import Annotated, Generic, Literal, TypeVar, Union from pydantic import BaseModel, Field -from shared.types.events.events import InstanceId +from shared.types.events import InstanceId from shared.types.tasks import Task from shared.types.worker.common import RunnerId from shared.types.worker.mlx import Host diff --git a/throwaway_tests/segfault_multiprocess.py b/throwaway_tests/segfault_multiprocess.py new file mode 100644 index 00000000..28c835f6 --- /dev/null +++ b/throwaway_tests/segfault_multiprocess.py @@ -0,0 +1,31 @@ +import ctypes +from multiprocessing import Process + +def trigger_segfault(): + ctypes.string_at(0) + +def subprocess_main(idx: int): + print(f"SUBPROCESS {idx}: PROCESS START") + trigger_segfault() + print(f"SUBPROCESS {idx}: PROCESS END") + +def main(): + """This code tests that a master process is not brought down by + segfaults that occur in the child processes. + """ + + print("MASTER: PROCESS START") + procs: list[Process] = [] + for i in range(10): + p = Process(target=subprocess_main, args=(i,)) + procs.append(p) + p.start() + + print("MASTER: JOINING SUBPROCESSES") + for p in procs: + p.join() + + print("MASTER: PROCESS END") + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock index d1fc02fc..3c541f99 100644 --- a/uv.lock +++ b/uv.lock @@ -15,7 +15,7 @@ members = [ "exo", "exo-engine-mlx", "exo-master", - "exo-networking", + "exo-pyo3-bindings", "exo-shared", "exo-worker", ] @@ -181,6 +181,8 @@ dependencies = [ { name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "exo-master", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "exo-worker", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typeguard", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = 
"types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] @@ -204,6 +206,8 @@ requires-dist = [ { name = "exo-master", editable = "master" }, { name = "exo-worker", editable = "worker" }, { name = "mlx", marker = "extra == 'darwin'" }, + { name = "pydantic", specifier = ">=2.11.7" }, + { name = "typeguard", specifier = ">=4.4.4" }, { name = "types-aiofiles", specifier = ">=24.1.0.20250708" }, ] provides-extras = ["darwin"] @@ -239,9 +243,25 @@ requires-dist = [ ] [[package]] -name = "exo-networking" +name = "exo-pyo3-bindings" version = "0.1.0" -source = { editable = "networking/topology" } +source = { editable = "rust/exo_pyo3_bindings" } + +[package.dev-dependencies] +dev = [ + { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + +[package.metadata] + +[package.metadata.requires-dev] +dev = [ + { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" }, + { name = "pytest", specifier = ">=8.4.0" }, + { name = "pytest-asyncio", specifier = ">=1.0.0" }, +] [[package]] name = "exo-shared" @@ -249,6 +269,7 @@ version = "0.1.0" source = { editable = "shared" } dependencies = [ { name = "aiosqlite", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pathlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -262,12 +283,15 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "types-protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] [package.metadata] requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, + { name = "greenlet", specifier = ">=3.2.3" }, { name = "networkx", specifier = ">=3.5" }, { name = "openai", specifier = ">=1.93.0" }, { name = "pathlib", specifier = ">=1.0.1" }, @@ -280,7 +304,11 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "types-protobuf", specifier = ">=6.30.2.20250516" }] +dev = [ + { name = "pytest", specifier = ">=8.4.0" }, + { name = "pytest-asyncio", specifier = ">=1.0.0" }, + { name = "types-protobuf", specifier = ">=6.30.2.20250516" }, +] [[package]] name = "exo-worker" @@ -288,6 +316,7 @@ version = "0.1.0" source = { editable = "worker" } dependencies = [ { name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] @@ -295,6 +324,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "exo-shared", editable = "shared" }, + { name = "huggingface-hub", specifier = ">=0.33.4" }, { name = "mlx", specifier = "==0.26.3" }, { name = "mlx-lm", specifier = ">=0.25.3" }, ] @@ -688,7 +718,7 @@ wheels = [ [[package]] name = "openai" -version = "1.97.0" 
+version = "1.97.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -700,9 +730,9 @@ dependencies = [ { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e0/c6/b8d66e4f3b95493a8957065b24533333c927dc23817abe397f13fe589c6e/openai-1.97.0.tar.gz", hash = "sha256:0be349569ccaa4fb54f97bb808423fd29ccaeb1246ee1be762e0c81a47bae0aa", size = 493850, upload-time = "2025-07-16T16:37:35.196Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/57/1c471f6b3efb879d26686d31582997615e969f3bb4458111c9705e56332e/openai-1.97.1.tar.gz", hash = "sha256:a744b27ae624e3d4135225da9b1c89c107a2a7e5bc4c93e5b7b5214772ce7a4e", size = 494267, upload-time = "2025-07-22T13:10:12.607Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/91/1f1cf577f745e956b276a8b1d3d76fa7a6ee0c2b05db3b001b900f2c71db/openai-1.97.0-py3-none-any.whl", hash = "sha256:a1c24d96f4609f3f7f51c9e1c2606d97cc6e334833438659cfd687e9c972c610", size = 764953, upload-time = "2025-07-16T16:37:33.135Z" }, + { url = "https://files.pythonhosted.org/packages/ee/35/412a0e9c3f0d37c94ed764b8ac7adae2d834dbd20e69f6aca582118e0f55/openai-1.97.1-py3-none-any.whl", hash = "sha256:4e96bbdf672ec3d44968c9ea39d2c375891db1acc1794668d8149d5fa6000606", size = 764380, upload-time = "2025-07-22T13:10:10.689Z" }, ] [[package]] @@ -1092,6 +1122,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382, upload-time = "2025-07-22T07:30:48.458Z" }, ] +[[package]] +name = "typeguard" +version = "4.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/68/71c1a15b5f65f40e91b65da23b8224dad41349894535a97f63a52e462196/typeguard-4.4.4.tar.gz", hash = "sha256:3a7fd2dffb705d4d0efaed4306a704c89b9dee850b688f060a8b1615a79e5f74", size = 75203, upload-time = "2025-06-18T09:56:07.624Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/a9/e3aee762739c1d7528da1c3e06d518503f8b6c439c35549b53735ba52ead/typeguard-4.4.4-py3-none-any.whl", hash = "sha256:b5f562281b6bfa1f5492470464730ef001646128b180769880468bd84b68b09e", size = 34874, upload-time = "2025-06-18T09:56:05.999Z" }, +] + [[package]] name = "types-aiofiles" version = "24.1.0.20250708" diff --git a/worker/main.py b/worker/main.py index 9bb6121e..5c73512f 100644 --- a/worker/main.py +++ b/worker/main.py @@ -8,8 +8,7 @@ from typing import AsyncGenerator, Optional from pydantic import BaseModel, ConfigDict from shared.types.common import NodeId -from shared.types.events.events import ChunkGenerated, InstanceId, RunnerStatusUpdated -from shared.types.events.registry import Event +from shared.types.events import ChunkGenerated, Event, InstanceId, RunnerStatusUpdated from shared.types.state import State from shared.types.worker.common import RunnerId from shared.types.worker.downloads import ( diff --git a/worker/pyproject.toml b/worker/pyproject.toml index 49ede7b7..b2e1a330 100644 --- a/worker/pyproject.toml +++ b/worker/pyproject.toml 
@@ -6,8 +6,9 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "exo-shared", + "huggingface_hub>=0.33.4", "mlx==0.26.3", "mlx-lm>=0.25.3", ] [build-system] diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py index 0812622c..e1a01ca3 100644 --- a/worker/tests/test_worker_handlers.py +++ b/worker/tests/test_worker_handlers.py @@ -7,9 +7,8 @@ from typing import Callable import pytest from shared.types.common import NodeId +from shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated from shared.types.events.chunks import TokenChunk -from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated -from shared.types.events.registry import Event from shared.types.tasks import Task from shared.types.worker.common import RunnerId from shared.types.worker.instances import Instance
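With `Event` exported as a single discriminated union, worker-side tests and handlers can dispatch structurally instead of consulting a registry. A hedged sketch of one such handler (illustrative only; this function is not part of the patch):

```python
from shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated

def handle(event: Event) -> None:
    # Keyword class patterns work on pydantic models without __match_args__.
    match event:
        case ChunkGenerated(request_id=request_id):
            print(f"chunk generated for request {request_id}")
        case RunnerStatusUpdated(runner_id=runner_id):
            print(f"runner {runner_id} changed status")
        case _:
            pass  # remaining Event variants are ignored in this sketch
```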