mirror of
https://github.com/exo-explore/exo.git
synced 2026-06-02 19:27:55 -04:00
libp2p -> zenoh
This commit is contained in:
10
AGENTS.md
10
AGENTS.md
@@ -4,7 +4,7 @@ This file provides guidance to AI coding agents when working with code in this r
|
||||
|
||||
## Project Overview
|
||||
|
||||
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
|
||||
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and zenoh for peer-to-peer networking.
|
||||
|
||||
## Build & Run Commands
|
||||
|
||||
@@ -69,7 +69,7 @@ If `nix fmt` changes any files, stage them before committing. The CI runs `nix f
|
||||
|
||||
### Node Composition
|
||||
A single exo `Node` (src/exo/main.py) runs multiple components:
|
||||
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
|
||||
- **Router**: zenoh-based pub/sub messaging via Rust bindings (exo_rs)
|
||||
- **Worker**: Handles inference tasks, downloads models, manages runner processes
|
||||
- **Master**: Coordinates cluster state, places model instances across nodes
|
||||
- **Election**: Bully algorithm for master election
|
||||
@@ -81,7 +81,7 @@ Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
|
||||
- `LOCAL_EVENTS`: Workers send events to master for indexing
|
||||
- `COMMANDS`: Workers/API send commands to master
|
||||
- `ELECTION_MESSAGES`: Election protocol messages
|
||||
- `CONNECTION_MESSAGES`: libp2p connection updates
|
||||
- `CONNECTION_MESSAGES`: zenoh connection updates
|
||||
|
||||
### Event Sourcing
|
||||
The system uses event sourcing for state management:
|
||||
@@ -98,8 +98,8 @@ The system uses event sourcing for state management:
|
||||
|
||||
### Rust Components
|
||||
Rust code in `rust/` provides:
|
||||
- `networking`: libp2p networking (gossipsub, peer discovery)
|
||||
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
|
||||
- `networking`: zenoh networking (pub/sub, peer discovery)
|
||||
- `exo_rs`: PyO3 bindings exposing Rust to Python
|
||||
- `system_custodian`: System-level operations
|
||||
|
||||
### Dashboard
|
||||
|
||||
4182
Cargo.lock
generated
4182
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
60
Cargo.toml
60
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[workspace]
|
||||
resolver = "3"
|
||||
members = ["rust/networking", "rust/exo_rs", "rust/util"]
|
||||
members = ["rust/exo_rs", "rust/networking"]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.0.1"
|
||||
@@ -20,31 +20,69 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Macro dependencies
|
||||
# pyo3
|
||||
pyo3 = "0.27.2"
|
||||
pyo3-async-runtimes = "0.27.0"
|
||||
pyo3-log = "0.13.2"
|
||||
pyo3-stub-gen = "0.22.2"
|
||||
|
||||
# util
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
|
||||
# Utility dependencies
|
||||
keccak-const = "0.2"
|
||||
nix = "0.31"
|
||||
|
||||
# Async dependencies
|
||||
async-stream = "0.3"
|
||||
tokio = "1.46"
|
||||
futures-lite = "2.6.1"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
async-stream = "0.3.6"
|
||||
pin-project = "1.1.10"
|
||||
serde_json = "1.0.149"
|
||||
rand = "0.10.1"
|
||||
parking_lot = "0.12.5"
|
||||
pidfile-rs = "0.3.1"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
env_logger = "0.11.10"
|
||||
|
||||
# networking
|
||||
libp2p = "0.56"
|
||||
libp2p-tcp = "0.44"
|
||||
zenoh = "=1.9.0"
|
||||
zenoh-plugin-storage-manager = { version = "=1.9.0", default-features = false }
|
||||
zenoh-plugin-trait = "=1.9.0"
|
||||
netwatcher = "0.6.0"
|
||||
bytemuck = "1.25.0"
|
||||
|
||||
[patch.crates-io]
|
||||
zenoh = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-buffers = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-codec = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-collections = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-config = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-core = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-crypto = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-keyexpr = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-commons = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-quic = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-quic_datagram = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-tcp = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-tls = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-udp = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-unixsock_stream = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-link-ws = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-macros = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-plugin-trait = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-protocol = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-result = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-runtime = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-sync = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-task = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-transport = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh-util = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
zenoh_backend_traits = { git = "https://github.com/evanev7/zenoh.git", branch = "exo" }
|
||||
|
||||
[workspace.lints.rust]
|
||||
static_mut_refs = "warn" # Or use "deny" instead of warn
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
# Missed things
|
||||
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
|
||||
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0, then rank=0 (this matches the auto_parallel implementation). NOTE: we use a different convention to mlx-lm: our terminal rank is rank=n-1 whereas mlx-lm's is rank=0, hence I can see why this was changed wrongly.
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] No mx_barrier in generate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[X] topology.py remove_node removes the connections after checking if node is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
@@ -352,7 +352,7 @@ final class ExoProcessController: ObservableObject {
|
||||
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
environment["EXO_ZENOH_NAMESPACE"] = computeNamespace()
|
||||
if !hfToken.isEmpty {
|
||||
environment["HF_TOKEN"] = hfToken
|
||||
}
|
||||
|
||||
12
flake.lock
generated
12
flake.lock
generated
@@ -47,11 +47,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1775807984,
|
||||
"narHash": "sha256-Redoe3D9zGN5I9QPHWL9vfMVQBehY1fKsMiRXQ83X3w=",
|
||||
"lastModified": 1777708550,
|
||||
"narHash": "sha256-Qif3UXT0l5OQq8H9pRWt4/ia4gF48MWK2oHKL8uVx8U=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "fcf90c0c4d368b2ca917a7afa6d08e98a397e5fd",
|
||||
"rev": "74c1591efaff494756b8d35ebe357c6c2bbdca96",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -218,11 +218,11 @@
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1775745684,
|
||||
"narHash": "sha256-8MbfLwd60FNa8dRFkjE+G3TT/x21G3Rsplm1bMBQUtU=",
|
||||
"lastModified": 1777639980,
|
||||
"narHash": "sha256-6d7Hdurvbjc5uwJuc0YiK7rZBGj6Gs3uzfBFcTs+xCc=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "64ddb549bc9a70d011328746fa46a8883f937b6b",
|
||||
"rev": "64cdaeb06f69b6b769a492edd88b022ae88e8ca2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -146,7 +146,7 @@
|
||||
config.treefmt.build.wrapper
|
||||
|
||||
# PYTHON
|
||||
self'.packages.exo.passthru.evenv
|
||||
#self'.packages.exo.passthru.evenv
|
||||
uv
|
||||
|
||||
# RUST
|
||||
|
||||
@@ -85,13 +85,6 @@ mlx = [
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx-0.32.0-cp313-cp313-manylinux_2_35_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx-0.32.0-cp313-cp313-manylinux_2_35_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine != 'aarch64'" },
|
||||
]
|
||||
mlx-lm = { git = "https://github.com/rltakashige/mlx-lm", branch = "leo/deepseek-v4" }
|
||||
mflux = { git = "https://github.com/evanev7/mflux", branch = "exo2" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'linux' and extra == 'mlx-cpu' and extra != 'mlx-cuda13' and extra != 'mlx-cuda12'" },
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' and extra == 'mlx-cuda12' and extra != 'mlx-cuda13' " },
|
||||
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' and extra == 'mlx-cuda13'" },
|
||||
]
|
||||
mlx-cuda-12 = [
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx_cuda_12-0.32.0-py3-none-manylinux_2_35_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx_cuda_12-0.32.0-py3-none-manylinux_2_35_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine != 'aarch64'" },
|
||||
@@ -100,6 +93,13 @@ mlx-cuda-13 = [
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx_cuda_13-0.32.0-py3-none-manylinux_2_35_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
|
||||
{ url = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv/releases/download/mlx_cuda/mlx_cuda_13-0.32.0-py3-none-manylinux_2_35_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine != 'aarch64'" },
|
||||
]
|
||||
mlx-lm = { git = "https://github.com/rltakashige/mlx-lm", branch = "leo/deepseek-v4" }
|
||||
mflux = { git = "https://github.com/evanev7/mflux", branch = "exo2" }
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'linux' and extra == 'mlx-cpu' and extra != 'mlx-cuda13' and extra != 'mlx-cuda12'" },
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' and extra == 'mlx-cuda12' and extra != 'mlx-cuda13' " },
|
||||
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' and extra == 'mlx-cuda13'" },
|
||||
]
|
||||
torchvision = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform == 'linux' and extra == 'mlx-cpu' and extra != 'mlx-cuda13' and extra != 'mlx-cuda12'" },
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' and extra == 'mlx-cuda12' and extra != 'mlx-cuda13'" },
|
||||
|
||||
@@ -22,48 +22,32 @@ doc = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
networking = { workspace = true }
|
||||
networking.workspace = true
|
||||
extend.workspace = true
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.28.3", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.22.3" }
|
||||
pyo3-async-runtimes = { version = "0.28.0", features = [
|
||||
pyo3 = { workspace = true, features = ["experimental-async"] }
|
||||
pyo3-stub-gen.workspace = true
|
||||
pyo3-async-runtimes = { workspace = true, features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
"testing",
|
||||
] }
|
||||
pyo3-log = "0.13.3"
|
||||
pyo3-log.workspace = true
|
||||
|
||||
pidfile-rs = { git = "https://github.com/AndreiCravtov/pidfile-rs" }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
thiserror = "2.0"
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures-lite = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures-lite.workspace = true
|
||||
pin-project.workspace = true
|
||||
|
||||
# Tracing
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
log.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
zenoh.workspace = true
|
||||
rand.workspace = true
|
||||
serde_json.workspace = true
|
||||
parking_lot.workspace = true
|
||||
|
||||
@@ -6,79 +6,16 @@ import os
|
||||
import pathlib
|
||||
import typing
|
||||
__all__ = [
|
||||
"AllQueuesFullError",
|
||||
"FromSwarm",
|
||||
"Keypair",
|
||||
"MessageTooLargeError",
|
||||
"NetworkingHandle",
|
||||
"NoPeersSubscribedToTopicError",
|
||||
"Pidfile",
|
||||
"PidfileError",
|
||||
"PyFromSwarm",
|
||||
]
|
||||
|
||||
@typing.final
|
||||
class AllQueuesFullError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
class FromSwarm:
|
||||
@typing.final
|
||||
class Connection(FromSwarm):
|
||||
__match_args__ = ("peer_id", "connected",)
|
||||
@property
|
||||
def peer_id(self) -> builtins.str: ...
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> FromSwarm.Connection: ...
|
||||
|
||||
@typing.final
|
||||
class Message(FromSwarm):
|
||||
__match_args__ = ("origin", "topic", "data",)
|
||||
@property
|
||||
def origin(self) -> builtins.str: ...
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> FromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Get the secret key bytes underlying the keypair
|
||||
"""
|
||||
def to_node_id(self) -> builtins.str:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair, bootstrap_peers: typing.Sequence[builtins.str], listen_port: builtins.int) -> NetworkingHandle: ...
|
||||
def recv(self) -> typing.Awaitable[FromSwarm]: ...
|
||||
@staticmethod
|
||||
def new(identity: builtins.str, listen_port: builtins.int, discovery_service_port: builtins.int) -> NetworkingHandle: ...
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
@@ -97,12 +34,7 @@ class NetworkingHandle:
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
async def recv(self) -> PyFromSwarm: ...
|
||||
|
||||
@typing.final
|
||||
class Pidfile:
|
||||
@@ -129,6 +61,7 @@ class Pidfile:
|
||||
def __new__(cls, path: builtins.str | os.PathLike | pathlib.Path, mode: builtins.int) -> Pidfile:
|
||||
r"""
|
||||
Creates a new PID file and locks it.
|
||||
Writes the current process ID to the PID file.
|
||||
|
||||
If the PID file cannot be locked, returns `PidfileError::AlreadyRunning` with
|
||||
a PID of the already running process, or `None` if no PID has been written to
|
||||
@@ -160,3 +93,22 @@ class PidfileError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
class PyFromSwarm:
|
||||
@typing.final
|
||||
class Connection(PyFromSwarm):
|
||||
__match_args__ = ("connected",)
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, connected: builtins.bool) -> PyFromSwarm.Connection: ...
|
||||
|
||||
@typing.final
|
||||
class Message(PyFromSwarm):
|
||||
__match_args__ = ("topic", "data",)
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -4,12 +4,12 @@ build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "exo_rs"
|
||||
version = "0.2.16"
|
||||
version = "0.3.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" },
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
@@ -18,8 +18,6 @@ dependencies = []
|
||||
dev = ["exo_rs", "pytest>=8.4.0", "pytest-asyncio>=1.0.0"]
|
||||
|
||||
[tool.maturin]
|
||||
#purelib = true
|
||||
#python-source = "python"
|
||||
module-name = "exo_rs"
|
||||
features = ["pyo3/extension-module", "pyo3/experimental-async"]
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::types::{PyBytes, PyBytesMethods as _};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
@@ -8,7 +7,7 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
pub struct PyKeypair(pub u128);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
@@ -17,31 +16,29 @@ impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
Self(rand::random())
|
||||
}
|
||||
|
||||
/// Construct an Ed25519 keypair from secret key bytes
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(u128::from_le_bytes(
|
||||
bytes
|
||||
.try_into()
|
||||
.map_err(|_| "passed too many bytes to from_bytes")
|
||||
.pyerr()?,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Get the secret key bytes underlying the keypair
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self
|
||||
.0
|
||||
.clone()
|
||||
.try_into_ed25519()
|
||||
.pyerr()?
|
||||
.secret()
|
||||
.as_ref()
|
||||
.to_vec();
|
||||
let bytes = self.0.to_le_bytes();
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
fn to_node_id(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
format!("{:x}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,23 +5,16 @@
|
||||
//!
|
||||
|
||||
mod allow_threading;
|
||||
mod ident;
|
||||
mod networking;
|
||||
mod pidfile;
|
||||
// mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pidfile::pidfile_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3::{Bound, PyResult, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
@@ -153,7 +146,7 @@ pub(crate) mod ext {
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule(name = "exo_rs", gil_used = true)]
|
||||
#[pymodule(name = "exo_rs")]
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
@@ -164,9 +157,10 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
m.add_class::<PyKeypair>()?;
|
||||
networking_submodule(m)?;
|
||||
pidfile_submodule(m)?;
|
||||
// m.add_class::<PyKeypair>()?;
|
||||
// networking_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
@@ -1,166 +1,40 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{
|
||||
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
|
||||
};
|
||||
use crate::pyclass;
|
||||
use futures_lite::{Stream, StreamExt as _};
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use networking::Session;
|
||||
use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::{PyRuntimeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
|
||||
pub struct PyNoPeersSubscribedToTopicError {}
|
||||
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
const MSG: &'static str = "\
|
||||
No peers are currently subscribed to receive messages on this topic. \
|
||||
Wait for peers to subscribe or check your network connectivity.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
|
||||
pub struct PyAllQueuesFullError {}
|
||||
|
||||
impl PyAllQueuesFullError {
|
||||
const MSG: &'static str =
|
||||
"All libp2p peers are unresponsive, resend the message or reconnect.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAllQueuesFullError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="MessageTooLargeError")]
|
||||
pub struct PyMessageTooLargeError {}
|
||||
|
||||
impl PyMessageTooLargeError {
|
||||
const MSG: &'static str = "Gossipsub message exceeds max_transmit_size. Reduce prompt length or increase the limit.";
|
||||
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(())
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyMessageTooLargeError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("MessageTooLargeError(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
struct PyNetworkingHandle {
|
||||
pub struct PyNetworkingHandle {
|
||||
// channels
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass(name = "FromSwarm")]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
#[pyclass]
|
||||
pub enum PyFromSwarm {
|
||||
Connection { connected: bool },
|
||||
Message { topic: String, data: Py<PyBytes> },
|
||||
}
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
FromSwarm::Discovered {} => Self::Connection { connected: true },
|
||||
FromSwarm::Expired {} => Self::Connection { connected: false },
|
||||
FromSwarm::Message { topic, data } => Self::Message {
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
@@ -168,6 +42,20 @@ impl From<FromSwarm> for PyFromSwarm {
|
||||
}
|
||||
}
|
||||
|
||||
impl PyNetworkingHandle {
|
||||
pub fn from_session(session: Session) -> Self {
|
||||
let (to_swarm, from_client) = mpsc::channel(1024);
|
||||
let swarm = Swarm {
|
||||
from_client,
|
||||
session,
|
||||
};
|
||||
PyNetworkingHandle {
|
||||
swarm: Arc::new(Mutex::new(swarm.into_stream())),
|
||||
to_swarm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNetworkingHandle {
|
||||
@@ -177,36 +65,49 @@ impl PyNetworkingHandle {
|
||||
|
||||
// ---- Lifecycle management methods ----
|
||||
|
||||
#[new]
|
||||
#[pyo3(signature = (identity, bootstrap_peers, listen_port))]
|
||||
fn py_new(
|
||||
identity: Bound<'_, PyKeypair>,
|
||||
bootstrap_peers: Vec<String>,
|
||||
#[staticmethod]
|
||||
pub fn new<'py>(
|
||||
identity: &str,
|
||||
listen_port: u16,
|
||||
) -> PyResult<Self> {
|
||||
discovery_service_port: u16,
|
||||
) -> PyResult<PyNetworkingHandle> {
|
||||
// todo: zenoh self assigned peers
|
||||
if listen_port == 0 {
|
||||
todo!();
|
||||
}
|
||||
// create communication channels
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_swarm, from_client) = mpsc::channel(1024);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
if !identity
|
||||
.chars()
|
||||
.all(|c| ('0'..='9').contains(&c) || ('a'..='f').contains(&c))
|
||||
|| identity.len() > 32
|
||||
{
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"{identity} is not a valid zenoh identity"
|
||||
)));
|
||||
}
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = create_swarm(identity, from_client, bootstrap_peers, listen_port)
|
||||
.pyerr()?
|
||||
.into_stream();
|
||||
let swarm = pyo3_async_runtimes::tokio::get_runtime()
|
||||
.block_on(create_swarm(
|
||||
identity,
|
||||
from_client,
|
||||
listen_port,
|
||||
discovery_service_port,
|
||||
))
|
||||
.pyerr()?;
|
||||
|
||||
Ok(Self {
|
||||
swarm: Arc::new(Mutex::new(swarm)),
|
||||
Ok(PyNetworkingHandle {
|
||||
swarm: Arc::new(Mutex::new(swarm.into_stream())),
|
||||
to_swarm,
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(override_return_type(
|
||||
type_repr="typing.Awaitable[FromSwarm]", imports=("typing")
|
||||
))]
|
||||
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = self.swarm.clone();
|
||||
#[gen_stub(skip)]
|
||||
pub fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = Arc::clone(&self.swarm);
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
swarm
|
||||
.try_lock()
|
||||
@@ -223,7 +124,7 @@ impl PyNetworkingHandle {
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
pub async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
@@ -245,7 +146,7 @@ impl PyNetworkingHandle {
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
pub async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
@@ -266,7 +167,7 @@ impl PyNetworkingHandle {
|
||||
/// Publishes a message with multiple topics to the `GossipSub` network.
|
||||
///
|
||||
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
pub async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to publish
|
||||
@@ -285,23 +186,21 @@ impl PyNetworkingHandle {
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
PyNoPeersSubscribedToTopicError::new_err()
|
||||
}
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
m.add_class::<exception::PyMessageTooLargeError>()?;
|
||||
pyo3_stub_gen::inventory::submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyNetworkingHandle:
|
||||
async def recv() -> PyFromSwarm: ...
|
||||
"#
|
||||
}
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
m.add_class::<PyFromSwarm>()?;
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ impl PyPidfile {
|
||||
#[pymethods]
|
||||
impl PyPidfile {
|
||||
/// Creates a new PID file and locks it.
|
||||
/// Writes the current process ID to the PID file.
|
||||
///
|
||||
/// If the PID file cannot be locked, returns `PidfileError::AlreadyRunning` with
|
||||
/// a PID of the already running process, or `None` if no PID has been written to
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.capture import CaptureFixture
|
||||
from exo_rs import (
|
||||
Keypair,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
Pidfile,
|
||||
FromSwarm,
|
||||
PyFromSwarm,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep_on_multiple_items() -> None:
|
||||
print("PYTHON: starting handle")
|
||||
h = NetworkingHandle(Keypair.generate(), [], 0)
|
||||
h = NetworkingHandle.new(os.urandom(16), [], 0)
|
||||
print("PYTHON: handle started")
|
||||
|
||||
rt = asyncio.create_task(_await_recv(h))
|
||||
|
||||
# sleep for 4 ticks
|
||||
for i in range(4):
|
||||
for i in range(10):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
await h.gossipsub_publish("topic", b"somehting or other")
|
||||
except NoPeersSubscribedToTopicError as e:
|
||||
print("caught it", e)
|
||||
await h.gossipsub_publish("topic", b"somehting or other")
|
||||
|
||||
|
||||
def test_pidfile(capsys: CaptureFixture[str]):
|
||||
@@ -39,11 +36,15 @@ async def _await_recv(h: NetworkingHandle):
|
||||
while True:
|
||||
event = await h.recv()
|
||||
match event:
|
||||
case FromSwarm.Connection() as c:
|
||||
case PyFromSwarm.Connection() as c:
|
||||
print(f"PYTHON: connection update: {c}")
|
||||
case FromSwarm.Message() as m:
|
||||
case PyFromSwarm.Message() as m:
|
||||
print(f"PYTHON: message: {m}")
|
||||
|
||||
|
||||
def scoped_lock_file() -> None:
    """Create a ``Pidfile`` at ``/tmp/lock.pid`` with mode ``0o600``.

    The object is bound only to a local name, so the reference is dropped
    when the function returns.
    """
    pidfile_guard = Pidfile("/tmp/lock.pid", 0o0600)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_sleep_on_multiple_items())
|
||||
|
||||
@@ -1,42 +1,24 @@
|
||||
[package]
|
||||
name = "networking"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "networking"
|
||||
path = "src/lib.rs"
|
||||
[dependencies]
|
||||
async-stream.workspace = true
|
||||
futures-lite.workspace = true
|
||||
netwatcher = { workspace = true, features = ["tokio"] }
|
||||
parking_lot.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
zenoh = { workspace = true, features = ["internal", "plugins", "unstable"] }
|
||||
zenoh-plugin-storage-manager.workspace = true
|
||||
zenoh-plugin-trait.workspace = true
|
||||
rand.workspace = true
|
||||
log.workspace = true
|
||||
bytemuck = { workspace = true, features = ["derive"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
async-stream = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = [
|
||||
"default",
|
||||
"env-filter",
|
||||
] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
tracing = "0.1.44"
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm = swarm::create_swarm(
|
||||
identity::Keypair::generate_ed25519(),
|
||||
from_client,
|
||||
vec![],
|
||||
0,
|
||||
)
|
||||
.expect("Swarm creation failed")
|
||||
.into_stream();
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
rx.await
|
||||
.expect("tx not dropped")
|
||||
.expect("subscribe shouldn't fail");
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm
|
||||
.send(swarm::ToSwarm::Publish {
|
||||
topic: "test-net".to_string(),
|
||||
data: line.as_bytes().to_vec(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
{
|
||||
println!("Send error: {e:?}");
|
||||
return;
|
||||
};
|
||||
match rx.await {
|
||||
Ok(Err(e)) => println!("Publish error: {e:?}"),
|
||||
Err(e) => println!("Publish error: {e:?}"),
|
||||
Ok(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
// on gossipsub outgoing
|
||||
match swarm.next().await {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => {
|
||||
println!("\n\nconnected to {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Expired { peer_id }) => {
|
||||
println!("\n\ndisconnected from {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Message { from, topic, data }) => {
|
||||
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
34
rust/networking/examples/do_nothing.rs
Normal file
34
rust/networking/examples/do_nothing.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use networking;
|
||||
use tracing::{info, warn};
|
||||
use zenoh::{Result, Wait};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
zenoh::init_log_from_env_or("info");
|
||||
info!("Opening session...");
|
||||
let cfg = networking::cfg(&format!("{:x}", rand::random::<u128>()), 52414)?;
|
||||
let session = networking::open(cfg, 52414, 52413).await?;
|
||||
let _tok = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_token(format!("nodes/{}/live", session.z.zid()))
|
||||
.wait()?;
|
||||
let subs = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_subscriber("**")
|
||||
.history(true)
|
||||
.wait()?;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::signal::ctrl_c() => break,
|
||||
s = subs.recv_async() => {
|
||||
match s {
|
||||
Err(e) => warn!("{e}"),
|
||||
Ok(s) => info!("{}: {}", s.kind(), s.key_expr().to_string().split("/").nth(1).unwrap()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
32
rust/networking/examples/heartbeat.rs
Normal file
32
rust/networking/examples/heartbeat.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use env_logger::Env;
|
||||
use log::info;
|
||||
use networking;
|
||||
use zenoh::{Result, Wait};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::try_init_from_env(Env::new().default_filter_or("info")).expect("logger failed");
|
||||
info!("Opening session...");
|
||||
let cfg = networking::cfg(&format!("{:x}", rand::random::<u128>()), 52414)?;
|
||||
let session = networking::open(cfg, 52414, 52413).await?;
|
||||
let _tok = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_token(format!("nodes/{}/live", session.z.zid()))
|
||||
.wait()?;
|
||||
session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_subscriber("**")
|
||||
.history(true)
|
||||
.callback(|tok| info!("{}: {}", tok.kind(), tok.key_expr().to_string()))
|
||||
.background()
|
||||
.wait()?;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::signal::ctrl_c() => break,
|
||||
_ = session.z.put("hello", "world") => {},
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
48
rust/networking/examples/put_string.rs
Normal file
48
rust/networking/examples/put_string.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use std::{env, time::Duration};
|
||||
|
||||
use env_logger::Env;
|
||||
use log::info;
|
||||
use networking;
|
||||
use zenoh::Result;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::try_init_from_env(Env::new().default_filter_or("info")).expect("logger failed");
|
||||
let n_bytes = env::args()
|
||||
.nth(1)
|
||||
.and_then(|it| it.parse::<usize>().ok())
|
||||
.expect("USAGE: put_string <n> -- pub a string of n bytes into stream/data");
|
||||
info!("Opening session...");
|
||||
let cfg = networking::cfg(&format!("{:x}", rand::random::<u128>()), 52414)?;
|
||||
let session = networking::open(cfg, 52414, 52413).await?;
|
||||
let _tok = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_token(format!("nodes/{}/live", session.z.zid()))
|
||||
.await?;
|
||||
let key_expr = "stream/data";
|
||||
let payload = "n".repeat(n_bytes);
|
||||
|
||||
let pubs = session
|
||||
.z
|
||||
.declare_publisher(key_expr)
|
||||
.congestion_control(zenoh::qos::CongestionControl::Block)
|
||||
.await?;
|
||||
let pubs_l = pubs.matching_listener().await?;
|
||||
if !pubs.matching_status().await?.matching() {
|
||||
while !pubs_l.recv_async().await?.matching() {}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
info!("Putting Data ('{key_expr}': '{}')...", payload.len());
|
||||
for _ in 0..10 {
|
||||
let t = tokio::time::Instant::now();
|
||||
for _ in 0..5000 {
|
||||
pubs.put(payload.clone()).await?;
|
||||
}
|
||||
info!("{:?}", t.elapsed());
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
tokio::signal::ctrl_c().await?;
|
||||
Ok(())
|
||||
}
|
||||
74
rust/networking/examples/serve_storage.rs
Normal file
74
rust/networking/examples/serve_storage.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use env_logger::Env;
|
||||
use log::info;
|
||||
use networking;
|
||||
use zenoh::Result;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::try_init_from_env(Env::new().default_filter_or("info")).expect("logger failed");
|
||||
info!("Opening session...");
|
||||
let cfg = networking::cfg(&format!("{:x}", rand::random::<u128>()), 52414)?;
|
||||
let session = networking::open(cfg, 52414, 52413).await?;
|
||||
let _tok = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_token(format!("nodes/{}/live", session.z.zid()))
|
||||
.await?;
|
||||
let _sub = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_subscriber("nodes/*/live")
|
||||
.history(true)
|
||||
.callback(|tok| {
|
||||
info!(
|
||||
"{}: {}",
|
||||
tok.kind(),
|
||||
tok.key_expr()
|
||||
.to_string()
|
||||
.strip_prefix("nodes/")
|
||||
.and_then(|it| it.strip_suffix("/live"))
|
||||
.unwrap()
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let watch = async {
|
||||
for _ in 0..1000 {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
session
|
||||
.z
|
||||
.get("**")
|
||||
.callback(|reply| {
|
||||
let sample = reply.into_result().expect("no errs");
|
||||
info!(
|
||||
"got {} bytes on {}",
|
||||
sample.payload().len(),
|
||||
sample.key_expr()
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
Result::<()>::Ok(())
|
||||
};
|
||||
let subs = session.z.declare_subscriber("**").await?;
|
||||
|
||||
let mut i = 0;
|
||||
let _a = async {
|
||||
while let Ok(sample) = subs.recv_async().await {
|
||||
i += 1;
|
||||
info!(
|
||||
"[{i}] received {} bytes on {}",
|
||||
sample.payload().len(),
|
||||
sample.key_expr()
|
||||
)
|
||||
}
|
||||
};
|
||||
tokio::select! {
|
||||
_ = watch => {},
|
||||
_ = _a => {},
|
||||
_ = tokio::signal::ctrl_c() => {},
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
41
rust/networking/examples/z_get.rs
Normal file
41
rust/networking/examples/z_get.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::{borrow::Cow, env};
|
||||
|
||||
use env_logger::Env;
|
||||
use log::{info, warn};
|
||||
use networking;
|
||||
use zenoh::{Result, Wait};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::try_init_from_env(Env::new().default_filter_or("info")).expect("logger failed");
|
||||
info!("Opening session...");
|
||||
let cfg = networking::cfg(&format!("{:x}", rand::random::<u128>()), 52414)?;
|
||||
let session = networking::open(cfg, 52414, 52413).await?;
|
||||
let other_live = session
|
||||
.z
|
||||
.liveliness()
|
||||
.declare_subscriber("**")
|
||||
.history(true)
|
||||
.wait()?;
|
||||
_ = other_live.recv_async().await?;
|
||||
let other_live = session.z.liveliness().get("**").wait()?;
|
||||
while let Ok(s) = other_live.recv_async().await {
|
||||
info!("{s:?}");
|
||||
}
|
||||
let query = env::args().nth(1).expect("USAGE: z_get [query]");
|
||||
info!("Querying {query}");
|
||||
let subs = session.z.liveliness().get(query).await?;
|
||||
while let Ok(r) = subs.recv_async().await {
|
||||
match r.into_result() {
|
||||
Ok(s) => info!(
|
||||
"{}: {}",
|
||||
s.key_expr(),
|
||||
s.payload()
|
||||
.try_to_string()
|
||||
.unwrap_or_else(|_| Cow::Borrowed("-bytes-"))
|
||||
),
|
||||
Err(e) => warn!("{e}"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
|
||||
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
|
||||
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
|
||||
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
|
||||
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
|
||||
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
|
||||
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
|
||||
|
||||
https://stackoverflow.com/questions/17169298/af-packet-on-osx
|
||||
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
|
||||
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
|
||||
|
||||
https://www.unix.com/man_page/mojave/4/ip/
|
||||
^OS X manpages for IP
|
||||
|
||||
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
|
||||
^driver kit, system extensions & kexts for macOS
|
||||
|
||||
----
|
||||
|
||||
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
|
||||
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
|
||||
----> here is a guide on how to set up thunderbolt-ethernet on linux
|
||||
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
|
||||
|
||||
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
|
||||
^GPT discussion about making socket interface
|
||||
|
||||
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
|
||||
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
|
||||
^GPT discussion about accessing TB on MacOS low level interactions
|
||||
|
||||
--------------------------------
|
||||
|
||||
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
|
||||
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
|
||||
|
||||
|
||||
---------------------------------
|
||||
|
||||
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
|
||||
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
|
||||
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
|
||||
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
@@ -1,390 +1,330 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures_lite::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{
|
||||
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
|
||||
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
|
||||
THandlerOutEvent, ToSwarm, dummy,
|
||||
use std::{
|
||||
env,
|
||||
hash::{DefaultHasher, Hash, Hasher},
|
||||
io,
|
||||
net::{Ipv6Addr, SocketAddr, SocketAddrV6},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use log::{debug, trace, warn};
|
||||
use netwatcher::WatchHandle;
|
||||
use parking_lot::Mutex;
|
||||
use tokio::{
|
||||
net::UdpSocket,
|
||||
time::{Interval, interval},
|
||||
};
|
||||
use zenoh::config::ZenohId;
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
const GROUP: Ipv6Addr = Ipv6Addr::new(0xff12, 0, 0, 0, 0, 0, 0xe0a1, 0xde89);
|
||||
const MAGIC: [u8; 3] = *b"EXO";
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
pub struct Discovery {
|
||||
sock: Arc<UdpSocket>,
|
||||
ifaces: Arc<Mutex<Vec<SocketAddrV6>>>,
|
||||
namespace: [u8; 8],
|
||||
last_nonce: Mutex<[u8; 8]>,
|
||||
/// the port of the service we are doing discovery for - transmitted to peers
|
||||
listen_port: u16,
|
||||
zid: ZenohId,
|
||||
tick: Interval,
|
||||
_sync: Mutex<WatchHandle>,
|
||||
}
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Discovered {
|
||||
pub zid: ZenohId,
|
||||
pub addr: SocketAddrV6,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
impl Discovery {
|
||||
pub async fn new(zid: ZenohId, listen_port: u16, discovery_port: u16) -> io::Result<Self> {
|
||||
let namespace = {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
env::var("EXO_ZENOH_NAMESPACE")
|
||||
.unwrap_or_else(|_| "exo".to_string())
|
||||
.hash(&mut hasher);
|
||||
hasher.finish().to_le_bytes()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(
|
||||
ping::Config::new()
|
||||
.with_timeout(PING_TIMEOUT)
|
||||
.with_interval(PING_INTERVAL),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
bootstrap_peers: Vec<Multiaddr>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair, bootstrap_peers: Vec<Multiaddr>) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
bootstrap_peers,
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => {
|
||||
self.managed
|
||||
.on_connection_handler_event(peer_id, connection_id, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
let sock = Arc::new(UdpSocket::bind(format!("[::]:{discovery_port}")).await?);
|
||||
//sock.set_multicast_loop_v6(false)?;
|
||||
let ifaces: Arc<Mutex<Vec<SocketAddrV6>>> = Default::default();
|
||||
let _sync = Mutex::new(
|
||||
netwatcher::watch_interfaces_with_callback({
|
||||
let sock = sock.clone();
|
||||
let ifaces = ifaces.clone();
|
||||
move |update| {
|
||||
for (iface_idx, iface) in update.interfaces.iter() {
|
||||
if iface
|
||||
.ipv6_ips()
|
||||
.all(|addr| addr.is_loopback() || addr.is_unspecified())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
match sock.join_multicast_v6(&GROUP, *iface_idx) {
|
||||
Ok(()) => ifaces
|
||||
.lock()
|
||||
.push(SocketAddrV6::new(GROUP, 52413, 0, *iface_idx)),
|
||||
Err(e) if e.kind() != io::ErrorKind::AddrInUse => {
|
||||
// skip AddrInUse - just means we've already joined the mv6
|
||||
if let Some(iface) = update.interfaces.get(&iface_idx) {
|
||||
warn!(
|
||||
"failed to join multicast v6 for interface {}: {e}",
|
||||
iface.name
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for iface_idx in update.diff.removed {
|
||||
ifaces.lock().retain(|addr| addr.scope_id() != iface_idx);
|
||||
|
||||
if let Err(e) = sock.leave_multicast_v6(&GROUP, iface_idx) {
|
||||
if let Some(iface) = update.interfaces.get(&iface_idx) {
|
||||
warn!(
|
||||
"failed to leave multicast v6 for interface {}: {e}",
|
||||
iface.name
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
// todo: better error handling here
|
||||
.expect("failed to bind discovery watcher"),
|
||||
);
|
||||
Ok(Self {
|
||||
sock,
|
||||
namespace,
|
||||
ifaces,
|
||||
last_nonce: Mutex::new(rand::random()),
|
||||
listen_port,
|
||||
zid,
|
||||
tick: interval(Duration::from_secs(1)),
|
||||
_sync,
|
||||
})
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
pub async fn next(&mut self) -> io::Result<Discovered> {
|
||||
let mut buf = [0u8; Hello::buf_size() + WhatsUp::buf_size() + 1];
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = self.tick.tick() => {
|
||||
self.announce().await?;
|
||||
}
|
||||
res = self.sock.recv_from(&mut buf) => {
|
||||
let Ok((bytes_read, addr)) = res else { continue; };
|
||||
if let Some(discovered) = self.respond(bytes_read, addr, &buf).await? {
|
||||
return Ok(discovered)
|
||||
}
|
||||
}
|
||||
}
|
||||
// dial bootstrap peers (for environments where mDNS is unavailable)
|
||||
for addr in &self.bootstrap_peers {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::unknown_peer_id().address(addr.clone()).build(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn respond(
|
||||
&self,
|
||||
bytes_read: usize,
|
||||
addr: SocketAddr,
|
||||
buf: &[u8],
|
||||
) -> io::Result<Option<Discovered>> {
|
||||
trace!(
|
||||
"raw recv: {bytes_read} bytes from {addr}: {:02x?}",
|
||||
&buf[..bytes_read]
|
||||
);
|
||||
if bytes_read < size_of::<Header>() {
|
||||
trace!("dropped: early EOF");
|
||||
return Ok(None);
|
||||
}
|
||||
let header: &Header = bytemuck::from_bytes(&buf[0..size_of::<Header>()]);
|
||||
if header.magic != MAGIC {
|
||||
trace!("dropped: wrong magic");
|
||||
return Ok(None);
|
||||
}
|
||||
let Ok(kind) = header.kind.try_into() else {
|
||||
trace!("dropped: unknown message kind {}", header.kind);
|
||||
return Ok(None);
|
||||
};
|
||||
match kind {
|
||||
Kind::Hello => {
|
||||
let total = Hello::buf_size();
|
||||
if bytes_read != total {
|
||||
trace!("dropped: hello wrong size");
|
||||
return Ok(None);
|
||||
}
|
||||
let hello: &Hello = bytemuck::from_bytes(&buf[size_of::<Header>()..total]);
|
||||
if hello.nonce == *self.last_nonce.lock() {
|
||||
trace!("dropped: local hello nonce");
|
||||
return Ok(None);
|
||||
}
|
||||
if hello.namespace != self.namespace {
|
||||
trace!("dropped: different namespace");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// reply
|
||||
trace!("replying to Hello({:?})", hello.nonce);
|
||||
let reply = WhatsUp {
|
||||
nonce: hello.nonce,
|
||||
zid: self.zid.to_le_bytes(),
|
||||
port_le: self.listen_port.to_le_bytes(),
|
||||
}
|
||||
.alloc();
|
||||
|
||||
for i in 1..6 {
|
||||
if self
|
||||
.sock
|
||||
.send_to(&reply, addr)
|
||||
.await
|
||||
.inspect_err(|e| debug!("send to {addr} failed: {e}"))
|
||||
.is_ok_and(|sent| sent == WhatsUp::buf_size())
|
||||
{
|
||||
trace!(
|
||||
"sent {} bytes to {addr} after {} attempt(s)",
|
||||
WhatsUp::buf_size(),
|
||||
i
|
||||
);
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
Kind::WhatsUp => {
|
||||
let total = WhatsUp::buf_size();
|
||||
if bytes_read != total {
|
||||
trace!("dropped: whatsup wrong size");
|
||||
return Ok(None);
|
||||
}
|
||||
let whats_up: &WhatsUp = bytemuck::from_bytes(&buf[size_of::<Header>()..total]);
|
||||
if whats_up.nonce != *self.last_nonce.lock() {
|
||||
trace!("dropped: stale nonce");
|
||||
return Ok(None);
|
||||
}
|
||||
let SocketAddr::V6(v6) = addr else {
|
||||
trace!("dropped: v4 addr used");
|
||||
return Ok(None);
|
||||
};
|
||||
let Ok(zid) = ZenohId::try_from(&whats_up.zid[..]) else {
|
||||
trace!("dropped: zenoh conversion failed");
|
||||
return Ok(None);
|
||||
};
|
||||
if zid == self.zid {
|
||||
trace!("dropped: self zenoh id");
|
||||
return Ok(None);
|
||||
}
|
||||
// discovery success!
|
||||
// the incoming port is our listen port;
|
||||
// overwrite it with the whats_up port corresponding to the remote zenoh service
|
||||
let addr = {
|
||||
let mut x = v6;
|
||||
x.set_port(u16::from_le_bytes(whats_up.port_le));
|
||||
x
|
||||
};
|
||||
Ok(Some(Discovered { addr, zid }))
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
async fn announce(&self) -> io::Result<()> {
|
||||
let nonce = rand::random();
|
||||
*self.last_nonce.lock() = nonce;
|
||||
let buf = Hello {
|
||||
nonce,
|
||||
namespace: self.namespace,
|
||||
}
|
||||
.alloc();
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
let addrs = self.ifaces.lock().clone();
|
||||
debug!("announcing Hello({nonce:?}) to {addrs:?}");
|
||||
// rev so .remove() doesn't break things
|
||||
for (i, addr) in addrs.into_iter().enumerate().rev() {
|
||||
match self.sock.send_to(&buf, addr).await {
|
||||
Ok(bytes) => trace!("sent {bytes} to {addr}"),
|
||||
Err(e) if e.kind() == io::ErrorKind::HostUnreachable => {
|
||||
debug!("disabling discovery address {addr}: {e}");
|
||||
_ = self.ifaces.lock().swap_remove(i);
|
||||
}
|
||||
Err(e) => debug!("failed to reach {addr}: {e}"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Kind of discovery packet, carried as the `kind` byte of the packet header.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Kind {
    /// Announcement packet (see `Hello`).
    Hello = 0,
    /// Reply packet (see `WhatsUp`).
    WhatsUp = 1,
}

/// Error returned when a header byte does not name any known [`Kind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnknownKind;

impl std::fmt::Display for UnknownKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("unknown discovery message kind")
    }
}

impl std::error::Error for UnknownKind {}

impl TryFrom<u8> for Kind {
    type Error = UnknownKind;

    /// Decodes a header byte into a [`Kind`].
    ///
    /// The accepted values are taken from the enum discriminants themselves,
    /// so decoding cannot drift from the `Kind as u8` encoding.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        const HELLO: u8 = Kind::Hello as u8;
        const WHATS_UP: u8 = Kind::WhatsUp as u8;

        match value {
            HELLO => Ok(Self::Hello),
            WHATS_UP => Ok(Self::WhatsUp),
            _ => Err(UnknownKind),
        }
    }
}
|
||||
|
||||
pub trait Message: Pod {
|
||||
const KIND: Kind;
|
||||
}
|
||||
// This should be part of the Message trait, but `const fn` in traits isn't stabilized. The macro lets alloc :: Self -> [u8; Self::buf_size()]
|
||||
macro_rules! impl_alloc {
    ($msg:ident) => {
        impl $msg {
            /// Number of bytes in an encoded packet: the header plus this message.
            const fn buf_size() -> usize {
                size_of::<Header>() + size_of::<Self>()
            }

            /// Encodes `self` into a stack buffer laid out as the header
            /// (magic + this type's `KIND`) followed by the raw bytes of `self`.
            pub fn alloc(self) -> [u8; Self::buf_size()] {
                let header = Header {
                    magic: MAGIC,
                    kind: Self::KIND as u8,
                };

                let mut packet = [0u8; Self::buf_size()];
                // split once at the header boundary and fill both halves
                let (head, body) = packet.split_at_mut(size_of::<Header>());
                head.copy_from_slice(bytemuck::bytes_of(&header));
                body.copy_from_slice(bytemuck::bytes_of(&self));
                packet
            }
        }
    };
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
|
||||
pub struct Header {
|
||||
magic: [u8; 3],
|
||||
kind: u8,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
|
||||
pub struct Hello {
|
||||
pub nonce: [u8; 8],
|
||||
pub namespace: [u8; 8],
|
||||
}
|
||||
impl Message for Hello {
|
||||
const KIND: Kind = Kind::Hello;
|
||||
}
|
||||
impl_alloc!(Hello);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
|
||||
pub struct WhatsUp {
|
||||
pub nonce: [u8; 8],
|
||||
pub zid: [u8; 16],
|
||||
pub port_le: [u8; 2],
|
||||
}
|
||||
impl Message for WhatsUp {
|
||||
const KIND: Kind = Kind::WhatsUp;
|
||||
}
|
||||
impl_alloc!(WhatsUp);
|
||||
|
||||
@@ -1,44 +1,105 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
use zenoh::{Result, Session as ZSession, config::Locator};
|
||||
use zenoh_plugin_storage_manager::StoragesPlugin;
|
||||
use zenoh_plugin_trait::PluginsManager;
|
||||
|
||||
pub use zenoh::{Config, config::ZenohId};
|
||||
|
||||
use crate::discovery::Discovery;
|
||||
|
||||
pub mod discovery;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
pub fn cfg(identity: &str, listen_port: u16) -> Result<zenoh::Config> {
|
||||
assert!(
|
||||
identity
|
||||
.chars()
|
||||
.all(|c| ('0'..='9').contains(&c) || ('a'..='f').contains(&c))
|
||||
);
|
||||
assert!(identity.len() <= 32);
|
||||
assert!(listen_port != 0, "must used defined listen port");
|
||||
let mut cfg = zenoh::Config::default();
|
||||
// todo: cleanup
|
||||
cfg.insert_json5("id", &format!("\"{identity}\""))?;
|
||||
cfg.insert_json5("mode", "\"router\"")?;
|
||||
cfg.insert_json5("listen/endpoints", &format!("[\"tcp/[::]:{listen_port}\"]"))?;
|
||||
cfg.insert_json5("scouting/multicast/enabled", "false")?;
|
||||
cfg.insert_json5("scouting/multicast/autoconnect", "[]")?;
|
||||
cfg.insert_json5("scouting/gossip/multihop", "true")?;
|
||||
cfg.insert_json5("adminspace/enabled", "true")?;
|
||||
//cfg.insert_json5("transport/link/tx/batch_size", "9216")?;
|
||||
cfg.insert_json5("transport/link/rx/buffer_size", "16777216")?;
|
||||
//cfg.insert_json5("timestamping/enabled", "true")?;
|
||||
cfg.insert_json5("plugins/storage_manager/__required__", "true")?;
|
||||
cfg.insert_json5(
|
||||
"plugins/storage_manager/storages/mem1",
|
||||
r#"{
|
||||
key_expr: "storage/mem1/**",
|
||||
strip_prefix: "storage/mem1",
|
||||
volume: "memory",
|
||||
replication: {
|
||||
interval: 2,
|
||||
}
|
||||
}"#,
|
||||
)?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use extend::ext;
|
||||
use libp2p::Multiaddr;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
use std::net::IpAddr;
|
||||
pub async fn open(
|
||||
cfg: zenoh::Config,
|
||||
listen_port: u16,
|
||||
discovery_service_port: u16,
|
||||
) -> Result<Session> {
|
||||
assert!(listen_port != 0, "must used defined listen port");
|
||||
let mut plugins = PluginsManager::static_plugins_only();
|
||||
plugins.declare_static_plugin::<StoragesPlugin, _>("storage_manager", true);
|
||||
let mut runtime = zenoh::internal::runtime::RuntimeBuilder::new(cfg)
|
||||
.plugins_manager(plugins)
|
||||
.build()
|
||||
.await?;
|
||||
let z = zenoh::session::init(runtime.clone().into()).await?;
|
||||
runtime.start().await?;
|
||||
let mut discovery = Discovery::new(z.zid(), listen_port, discovery_service_port).await?;
|
||||
let _jh = Arc::new(AbortOnDrop(tokio::task::spawn(async move {
|
||||
loop {
|
||||
let Ok(discovered) = discovery.next().await.inspect_err(|e| {
|
||||
log::warn!("discovery error {e}");
|
||||
}) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
#[ext(pub, name = MultiaddrExt)]
|
||||
impl Multiaddr {
|
||||
/// If the multiaddress corresponds to a TCP address, extracts it
|
||||
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
|
||||
let mut ps = self.into_iter();
|
||||
let ip = if let Some(p) = ps.next() {
|
||||
match p {
|
||||
Protocol::Ip4(ip) => IpAddr::V4(ip),
|
||||
Protocol::Ip6(ip) => IpAddr::V6(ip),
|
||||
_ => return None,
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
if discovered.zid > runtime.zid() {
|
||||
log::debug!("not connecting to peer with greater zid");
|
||||
continue;
|
||||
}
|
||||
|
||||
let Ok(locator) =
|
||||
Locator::new("tcp", discovered.addr.to_string(), "").inspect_err(|e| {
|
||||
log::warn!("failed to parse locator from addr: {e}");
|
||||
})
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
let Some(Protocol::Tcp(port)) = ps.next() else {
|
||||
return None;
|
||||
};
|
||||
Some((ip, port))
|
||||
|
||||
runtime
|
||||
.connect_peer(&discovered.zid.into(), &[locator])
|
||||
.await;
|
||||
}
|
||||
})));
|
||||
Ok(Session { z, _jh })
|
||||
}
|
||||
|
||||
struct AbortOnDrop(JoinHandle<()>);
|
||||
impl Drop for AbortOnDrop {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub z: ZSession,
|
||||
_jh: Arc<AbortOnDrop>,
|
||||
}
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
//! Compat shim for the old libp2p code
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use futures_lite::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use zenoh::Result;
|
||||
use zenoh::Session;
|
||||
use zenoh::handlers::FifoChannelHandler;
|
||||
use zenoh::liveliness::LivelinessToken;
|
||||
use zenoh::pubsub::Publisher;
|
||||
use zenoh::pubsub::Subscriber;
|
||||
use zenoh::qos::CongestionControl;
|
||||
use zenoh::sample::Sample;
|
||||
use zenoh::sample::SampleKind;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
///
|
||||
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
|
||||
/// even be, and how to inject the right version into this config/initialization. E.g. should
|
||||
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
|
||||
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
// Uses oneshot senders to emulate function calling apis while avoiding requiring unique ownership
|
||||
// of the Swarm.
|
||||
#[derive(Debug)]
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
@@ -26,52 +24,66 @@ pub enum ToSwarm {
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
result_sender: oneshot::Sender<Result<bool>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
result_sender: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Message { topic: String, data: Vec<u8> },
|
||||
Discovered {},
|
||||
Expired {},
|
||||
}
|
||||
|
||||
pub type Topics = HashMap<String, (Subscriber<()>, Publisher<'static>)>;
|
||||
pub struct Swarm {
|
||||
swarm: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
pub session: crate::Session,
|
||||
pub from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
|
||||
impl Swarm {
|
||||
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
|
||||
let Swarm {
|
||||
mut swarm,
|
||||
session,
|
||||
mut from_client,
|
||||
} = self;
|
||||
let stream = async_stream::stream! {
|
||||
let mut session = session;
|
||||
let (mut to_topics, mut from_topics) = mpsc::channel(1024);
|
||||
let mut topics = Topics::new();
|
||||
let Ok((_token, discovery)) = register_liveness(&mut session.z).await else { return; };
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = from_client.recv() => {
|
||||
let Some(msg) = msg else { break };
|
||||
on_message(&mut swarm, msg);
|
||||
on_message(&mut session.z, &mut topics, &mut to_topics, msg).await;
|
||||
}
|
||||
event = swarm.next() => {
|
||||
let Some(event) = event else { break };
|
||||
if let Some(item) = filter_swarm_event(event) {
|
||||
yield item;
|
||||
event = from_topics.recv() => {
|
||||
if let Some(event) = event {
|
||||
yield event
|
||||
}
|
||||
}
|
||||
token = discovery.recv_async() => {
|
||||
if let Ok(token) = token {
|
||||
let key_expr = token.key_expr().as_str().to_owned();
|
||||
let nid = key_expr.strip_prefix("live/");
|
||||
yield match token.kind() {
|
||||
SampleKind::Put => {
|
||||
log::info!("discovered: {nid:?}");
|
||||
FromSwarm::Discovered {}
|
||||
}
|
||||
SampleKind::Delete => {
|
||||
log::info!("expired: {nid:?}");
|
||||
FromSwarm::Expired {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -79,208 +91,117 @@ impl Swarm {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
async fn register_liveness(
|
||||
session: &mut Session,
|
||||
) -> Result<(LivelinessToken, Subscriber<FifoChannelHandler<Sample>>)> {
|
||||
let token = session
|
||||
.liveliness()
|
||||
.declare_token(format!("live/{}", session.zid()))
|
||||
.await?;
|
||||
let sub = session
|
||||
.liveliness()
|
||||
.declare_subscriber("live/*")
|
||||
.history(true)
|
||||
.await?;
|
||||
Ok((token, sub))
|
||||
}
|
||||
|
||||
async fn on_message(
|
||||
session: &mut Session,
|
||||
topics: &mut Topics,
|
||||
to_topics: &mut mpsc::Sender<FromSwarm>,
|
||||
msg: ToSwarm,
|
||||
) {
|
||||
match msg {
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
_ = result_sender.send(result);
|
||||
let res = match topics.get(&topic) {
|
||||
Some(topic) => topic.1.put(data).await,
|
||||
None => {
|
||||
// TODO: this should be an error but the python FromSwarm is somewhat nondeterministic
|
||||
Ok(()) //Err("not subscribed to topic!".into()),
|
||||
}
|
||||
};
|
||||
_ = result_sender.send(res);
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let Some((_, (subscriber, publisher))) = topics.remove_entry(&topic) else {
|
||||
_ = result_sender.send(false);
|
||||
return;
|
||||
};
|
||||
_ = publisher.undeclare().await;
|
||||
_ = subscriber.undeclare().await;
|
||||
_ = result_sender.send(true);
|
||||
}
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
assert!(topic.is_ascii());
|
||||
if topics.contains_key(&topic) {
|
||||
_ = result_sender.send(Ok(false));
|
||||
return;
|
||||
}
|
||||
|
||||
let publisher_res = session
|
||||
.declare_publisher(format!("topics/{topic}"))
|
||||
.congestion_control(CongestionControl::Block)
|
||||
.await;
|
||||
let publisher = match publisher_res {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
_ = result_sender.send(Err(e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let subscriber_res = session
|
||||
.declare_subscriber(format!("topics/{topic}"))
|
||||
.allowed_origin(zenoh::sample::Locality::Remote)
|
||||
.callback({
|
||||
let sender = to_topics.clone();
|
||||
let topic = topic.clone();
|
||||
move |sample| {
|
||||
if sample.kind() != SampleKind::Put {
|
||||
return;
|
||||
}
|
||||
_ = sender.try_send(FromSwarm::Message {
|
||||
topic: topic.clone(),
|
||||
data: sample.payload().to_bytes().to_vec(),
|
||||
});
|
||||
}
|
||||
})
|
||||
.await;
|
||||
let subscriber = match subscriber_res {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
_ = result_sender.send(Err(e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
assert!(topics.insert(topic, (subscriber, publisher)).is_none());
|
||||
_ = result_sender.send(Ok(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
match event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
}),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Some(FromSwarm::Discovered { peer_id }),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
|
||||
peer_id,
|
||||
..
|
||||
})) => Some(FromSwarm::Expired { peer_id }),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create and configure a swarm.
|
||||
///
|
||||
/// - `listen_port`: TCP port to listen on. `0` lets the OS assign one.
|
||||
/// - `bootstrap_peers`: multiaddrs to dial for environments without mDNS.
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
pub async fn create_swarm(
|
||||
identity: &str,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
bootstrap_peers: Vec<String>,
|
||||
listen_port: u16,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
let parsed_bootstrap_peers: Vec<libp2p::Multiaddr> = bootstrap_peers
|
||||
.iter()
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter_map(|s| s.parse().ok())
|
||||
.collect();
|
||||
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
.with_behaviour(|keypair| Behaviour::new(keypair, parsed_bootstrap_peers))?
|
||||
.build();
|
||||
|
||||
swarm.listen_on(format!("/ip4/0.0.0.0/tcp/{listen_port}").parse()?)?;
|
||||
Ok(Swarm { swarm, from_client })
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures_lite::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
use libp2p::pnet::{PnetError, PnetOutput};
|
||||
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
|
||||
use std::{env, sync::LazyLock};
|
||||
|
||||
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
|
||||
/// See [`pnet_upgrade`] for more.
|
||||
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
|
||||
let builder = Sha3_256::new().update(b"exo_discovery_network");
|
||||
|
||||
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
|
||||
let bytes = var.into_bytes();
|
||||
builder.update(&bytes)
|
||||
} else {
|
||||
builder.update(NETWORK_VERSION)
|
||||
}
|
||||
.finalize()
|
||||
});
|
||||
|
||||
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
|
||||
/// also different-versioned instances of this same network.
|
||||
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
|
||||
async fn pnet_upgrade<TSocket>(
|
||||
socket: TSocket,
|
||||
_: impl Sized,
|
||||
) -> Result<PnetOutput<TSocket>, PnetError>
|
||||
where
|
||||
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
use pnet::{PnetConfig, PreSharedKey};
|
||||
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
|
||||
.handshake(socket)
|
||||
.await
|
||||
}
|
||||
|
||||
/// TCP/IP transport layer configuration.
|
||||
pub fn tcp_transport(
|
||||
keypair: &identity::Keypair,
|
||||
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
|
||||
use libp2p::{
|
||||
core::upgrade::Version,
|
||||
tcp::{Config, tokio},
|
||||
};
|
||||
|
||||
// `TCP_NODELAY` enabled => avoid latency
|
||||
let tcp_config = Config::default().nodelay(true);
|
||||
|
||||
// V1 + lazy flushing => 0-RTT negotiation
|
||||
let upgrade_version = Version::V1Lazy;
|
||||
|
||||
// Noise is faster than TLS + we don't care much for security
|
||||
let noise_config = noise::Config::new(keypair)?;
|
||||
|
||||
// Use default Yamux config for multiplexing
|
||||
let yamux_config = yamux::Config::default();
|
||||
|
||||
// Create new Tokio-driven TCP/IP transport layer
|
||||
let base_transport = tokio::Transport::new(tcp_config)
|
||||
.and_then(pnet_upgrade)
|
||||
.upgrade(upgrade_version)
|
||||
.authenticate(noise_config)
|
||||
.multiplex(yamux_config);
|
||||
|
||||
// Return boxed transport (to flatten complex type)
|
||||
Ok(base_transport.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(
|
||||
keypair: &identity::Keypair,
|
||||
bootstrap_peers: Vec<libp2p::Multiaddr>,
|
||||
) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair, bootstrap_peers)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
// => signed message authenticity + strict validation mode means the message-ID is
|
||||
// automatically provided by gossipsub w/out needing to provide custom message-ID function
|
||||
gossipsub::Behaviour::new(
|
||||
MessageAuthenticity::Signed(keypair.clone()),
|
||||
ConfigBuilder::default()
|
||||
.max_transmit_size(8 * 1024 * 1024)
|
||||
.validation_mode(ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("the configuration should always be valid"),
|
||||
)
|
||||
.expect("creating gossipsub behavior should always work")
|
||||
}
|
||||
discovery_service_port: u16,
|
||||
) -> Result<Swarm> {
|
||||
let cfg = crate::cfg(identity, listen_port)?;
|
||||
let session = crate::open(cfg, listen_port, discovery_service_port).await?;
|
||||
Ok(Swarm {
|
||||
session,
|
||||
from_client,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
use futures_lite::StreamExt;
|
||||
use networking::swarm::{FromSwarm, create_swarm};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
|
||||
/// Helper: find a free TCP port.
|
||||
fn free_port() -> u16 {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
listener.local_addr().unwrap().port()
|
||||
}
|
||||
|
||||
/// Two nodes connect via bootstrap peers — no mDNS needed.
|
||||
///
|
||||
/// Node A listens on a fixed port. Node B bootstraps to A's address.
|
||||
/// We verify that B emits `FromSwarm::Discovered` for A's peer ID.
|
||||
#[tokio::test]
|
||||
async fn two_nodes_connect_via_bootstrap_peers() {
|
||||
let port_a = free_port();
|
||||
|
||||
// Node A: listens on a known port, no bootstrap peers
|
||||
let keypair_a = libp2p::identity::Keypair::generate_ed25519();
|
||||
let peer_id_a = keypair_a.public().to_peer_id();
|
||||
let (_tx_a, rx_a) = mpsc::channel(16);
|
||||
let swarm_a = create_swarm(keypair_a, rx_a, vec![], port_a).expect("create swarm A");
|
||||
let mut stream_a = swarm_a.into_stream();
|
||||
|
||||
// Node B: bootstraps to A's address
|
||||
let keypair_b = libp2p::identity::Keypair::generate_ed25519();
|
||||
let (_tx_b, rx_b) = mpsc::channel(16);
|
||||
let swarm_b = create_swarm(
|
||||
keypair_b,
|
||||
rx_b,
|
||||
vec![format!("/ip4/127.0.0.1/tcp/{port_a}")],
|
||||
0,
|
||||
)
|
||||
.expect("create swarm B");
|
||||
let mut stream_b = swarm_b.into_stream();
|
||||
|
||||
// Wait for B to discover A (connection established)
|
||||
let connected = timeout(Duration::from_secs(10), async {
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(event) = stream_a.next() => {
|
||||
// A will also see B connect, but we check from B's perspective
|
||||
let _ = event;
|
||||
}
|
||||
Some(event) = stream_b.next() => {
|
||||
if let FromSwarm::Discovered { peer_id } = event {
|
||||
if peer_id == peer_id_a {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
connected.is_ok() && connected.unwrap(),
|
||||
"Node B should discover Node A via bootstrap peer"
|
||||
);
|
||||
}
|
||||
|
||||
/// Empty bootstrap peers should work (backward compatible).
|
||||
#[tokio::test]
|
||||
async fn create_swarm_with_empty_bootstrap_peers() {
|
||||
let keypair = libp2p::identity::Keypair::generate_ed25519();
|
||||
let (_tx, rx) = mpsc::channel(16);
|
||||
let swarm = create_swarm(keypair, rx, vec![], 0);
|
||||
assert!(
|
||||
swarm.is_ok(),
|
||||
"create_swarm with no bootstrap peers should succeed"
|
||||
);
|
||||
}
|
||||
|
||||
/// Invalid multiaddr strings are silently filtered out.
|
||||
#[tokio::test]
|
||||
async fn create_swarm_ignores_invalid_bootstrap_addrs() {
|
||||
let keypair = libp2p::identity::Keypair::generate_ed25519();
|
||||
let (_tx, rx) = mpsc::channel(16);
|
||||
let swarm = create_swarm(
|
||||
keypair,
|
||||
rx,
|
||||
vec![
|
||||
"not-a-valid-multiaddr".to_string(),
|
||||
"".to_string(),
|
||||
"/ip4/10.0.0.1/tcp/30000".to_string(), // valid
|
||||
],
|
||||
0,
|
||||
);
|
||||
assert!(
|
||||
swarm.is_ok(),
|
||||
"create_swarm should succeed even with invalid bootstrap addrs"
|
||||
);
|
||||
}
|
||||
|
||||
/// Fixed listen port works correctly.
|
||||
#[tokio::test]
|
||||
async fn create_swarm_with_fixed_port() {
|
||||
let port = free_port();
|
||||
let keypair = libp2p::identity::Keypair::generate_ed25519();
|
||||
let (_tx, rx) = mpsc::channel(16);
|
||||
let swarm = create_swarm(keypair, rx, vec![], port);
|
||||
assert!(swarm.is_ok(), "create_swarm with fixed port should succeed");
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
// maybe this will hold test in the future...??
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn does_nothing() {}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
[package]
|
||||
name = "util"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "util"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
@@ -1 +0,0 @@
|
||||
pub mod wakerdeque;
|
||||
@@ -1,55 +0,0 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
|
||||
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
|
||||
pub struct WakerDeque<T> {
|
||||
waker: Option<Waker>,
|
||||
deque: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Debug for WakerDeque<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.deque.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WakerDeque<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
waker: None,
|
||||
deque: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, cx: &mut Context<'_>) {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
fn wake(&mut self) {
|
||||
let Some(ref mut w) = self.waker else { return };
|
||||
w.wake_by_ref();
|
||||
self.waker = None;
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_front()
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_back()
|
||||
}
|
||||
|
||||
pub fn push_front(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_front(value);
|
||||
}
|
||||
|
||||
pub fn push_back(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_back(value);
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.main import Master
|
||||
from exo.routing.event_router import EventRouter
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
from exo.routing.router import Router
|
||||
from exo.shared.constants import EXO_DEFAULT_MODELS_DIR, EXO_LOG, EXO_PID_FILE
|
||||
from exo.shared.election import Election, ElectionResult
|
||||
from exo.shared.logging import logger_cleanup, logger_setup
|
||||
@@ -50,13 +50,11 @@ class Node:
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> Self:
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
identity = os.urandom(16).hex().lstrip("0")
|
||||
node_id = NodeId(identity)
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(
|
||||
keypair,
|
||||
bootstrap_peers=args.bootstrap_peers,
|
||||
listen_port=args.libp2p_port,
|
||||
identity, listen_port=args.zenoh_port, discovery_service_port=52413
|
||||
)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
await router.register_topic(topics.LOCAL_EVENTS)
|
||||
@@ -338,16 +336,18 @@ def main_inner(args: "Args"):
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
|
||||
logger.info(f"{'=' * 40}")
|
||||
logger.info(f"Starting EXO | pid={os.getpid()}")
|
||||
logger.info(f"{'=' * 40}")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
logger.info(f"pid = {os.getpid()}")
|
||||
if os.getenv("EXO_LIBP2P_NAMESPACE"):
|
||||
raise ValueError(
|
||||
"EXO_LIBP2P_NAMESPACE has been removed - use EXO_ZENOH_NAMESPACE instead"
|
||||
)
|
||||
logger.info(f"EXO_ZENOH_NAMESPACE: {os.getenv('EXO_ZENOH_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
if args.bootstrap_peers:
|
||||
logger.info(f"Bootstrap peers: {args.bootstrap_peers}")
|
||||
raise ValueError("Bootstrap peers has been temporarily removed")
|
||||
|
||||
if args.no_batch:
|
||||
os.environ["EXO_NO_BATCH"] = "1"
|
||||
@@ -387,7 +387,7 @@ class Args(FrozenModel):
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
legacy_daemon: bool = False
|
||||
bootstrap_peers: list[str] = []
|
||||
libp2p_port: int
|
||||
zenoh_port: int
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -460,11 +460,11 @@ class Args(FrozenModel):
|
||||
help="Comma-separated libp2p multiaddrs to dial on startup (env: EXO_BOOTSTRAP_PEERS)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--libp2p-port",
|
||||
"--zenoh-port",
|
||||
type=int,
|
||||
default=0,
|
||||
dest="libp2p_port",
|
||||
help="Fixed TCP port for libp2p to listen on (0 = OS-assigned).",
|
||||
default=52414,
|
||||
dest="zenoh_port",
|
||||
help="Fixed TCP port for zenoh to listen.",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
|
||||
@@ -464,7 +464,7 @@ class Master:
|
||||
)
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
except ValueError as e:
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning("Error in command processor")
|
||||
|
||||
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.routing.router import get_node_zid
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.backends import Backend
|
||||
from exo.shared.types.commands import (
|
||||
@@ -16,7 +16,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import ModelId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
GlobalForwarderEvent,
|
||||
@@ -49,8 +49,7 @@ from exo.utils.info_gatherer.info_gatherer import NodeBackends
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
node_id = get_node_zid()
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
from exo_rs import FromSwarm
|
||||
from exo_rs import PyFromSwarm
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessage(FrozenModel):
|
||||
node_id: NodeId
|
||||
connected: bool
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: FromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(node_id=NodeId(update.peer_id), connected=update.connected)
|
||||
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(connected=update.connected)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
import os
|
||||
from copy import copy
|
||||
from itertools import count
|
||||
from math import inf
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
@@ -13,17 +12,13 @@ from anyio import (
|
||||
sleep_forever,
|
||||
)
|
||||
from exo_rs import (
|
||||
AllQueuesFullError,
|
||||
FromSwarm,
|
||||
Keypair,
|
||||
MessageTooLargeError,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
PyFromSwarm,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.shared.constants import EXO_NODE_ZID
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
@@ -105,12 +100,12 @@ class Router:
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
identity: Keypair,
|
||||
bootstrap_peers: Sequence[str] = (),
|
||||
listen_port: int = 0,
|
||||
identity: str,
|
||||
listen_port: int,
|
||||
discovery_service_port: int,
|
||||
) -> "Router":
|
||||
return cls(
|
||||
handle=NetworkingHandle(identity, list(bootstrap_peers), listen_port)
|
||||
handle=NetworkingHandle.new(identity, listen_port, discovery_service_port)
|
||||
)
|
||||
|
||||
def __init__(self, handle: NetworkingHandle):
|
||||
@@ -191,10 +186,8 @@ class Router:
|
||||
from_swarm = await self._net.recv()
|
||||
logger.debug(from_swarm)
|
||||
match from_swarm:
|
||||
case FromSwarm.Message(origin, topic, data):
|
||||
logger.trace(
|
||||
f"Received message on {topic} from {origin} with payload {data}"
|
||||
)
|
||||
case PyFromSwarm.Message(topic, data):
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
@@ -202,7 +195,7 @@ class Router:
|
||||
continue
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
case FromSwarm.Connection():
|
||||
case PyFromSwarm.Connection():
|
||||
message = ConnectionMessage.from_update(from_swarm)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
@@ -225,33 +218,25 @@ class Router:
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
try:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
if len(data) > 1024 * 1024:
|
||||
logger.warning(
|
||||
"Sending overlarge payload, network performance may be temporarily degraded"
|
||||
)
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
except NoPeersSubscribedToTopicError:
|
||||
pass
|
||||
except AllQueuesFullError:
|
||||
logger.warning(f"All peer queues full, dropping message on {topic}")
|
||||
except MessageTooLargeError:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
if len(data) > 1024 * 1024:
|
||||
logger.warning(
|
||||
f"Message too large for gossipsub on {topic} ({len(data)} bytes), dropping"
|
||||
"Sending overlarge payload, network performance may be temporarily degraded"
|
||||
)
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR,
|
||||
) -> Keypair:
|
||||
def get_node_zid(
|
||||
path: Path = EXO_NODE_ZID,
|
||||
) -> NodeId:
|
||||
"""
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate()
|
||||
return NodeId(os.urandom(16).hex())
|
||||
|
||||
"""
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -273,3 +258,4 @@ def get_node_id_keypair(
|
||||
keypair = Keypair.generate()
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
"""
|
||||
|
||||
@@ -76,7 +76,7 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
|
||||
EXO_PID_FILE = EXO_CACHE_HOME / "exo.pid"
|
||||
|
||||
# Identity (config)
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
||||
EXO_NODE_ZID = EXO_CACHE_HOME / "node_zid"
|
||||
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
|
||||
|
||||
# libp2p topics for event forwarding
|
||||
|
||||
@@ -46,7 +46,8 @@ class _InterceptHandler(logging.Handler):
|
||||
def logger_setup(log_file: Path | None, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
|
||||
logging.getLogger("exo_rs").setLevel(logging.WARNING)
|
||||
logging.getLogger("exo_rs").setLevel(logging.INFO)
|
||||
logging.getLogger("networking").setLevel(logging.INFO)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
|
||||
await cm_tx.send(ConnectionMessage(connected=True))
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
|
||||
@@ -10,8 +10,8 @@ from multiprocessing.synchronize import Semaphore as SemaphoreT
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture, mark
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.routing.router import get_node_zid
|
||||
from exo.shared.constants import EXO_NODE_ZID
|
||||
|
||||
NUM_CONCURRENT_PROCS = 10
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_bytes())
|
||||
queue.put(get_node_zid().encode())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
@@ -79,7 +79,7 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
reps = 10
|
||||
|
||||
# delete current file and write a new one
|
||||
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
|
||||
_delete_if_exists(EXO_NODE_ZID)
|
||||
kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
|
||||
|
||||
with caplog.at_level(101): # supress logs
|
||||
@@ -88,6 +88,6 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
|
||||
|
||||
# make sure that after deleting, we are not fetching the same value
|
||||
_delete_if_exists(EXO_NODE_ID_KEYPAIR)
|
||||
_delete_if_exists(EXO_NODE_ZID)
|
||||
for _ in range(reps):
|
||||
assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
|
||||
|
||||
@@ -97,13 +97,6 @@ def test_macos_uses_traditional_paths():
|
||||
assert home / ".exo" == constants.EXO_CACHE_HOME
|
||||
|
||||
|
||||
def test_node_id_in_config_dir():
|
||||
"""Test that node ID keypair is in the config directory."""
|
||||
import exo.shared.constants as constants
|
||||
|
||||
assert constants.EXO_NODE_ID_KEYPAIR.parent == constants.EXO_CONFIG_HOME
|
||||
|
||||
|
||||
def test_models_in_data_dir():
|
||||
"""Test that default models directory is in the data directory."""
|
||||
# Clear EXO_MODELS_DIRS to test default behavior
|
||||
|
||||
@@ -38,7 +38,7 @@ def print_startup_banner(port: int) -> None:
|
||||
|
||||
╔═══════════════════════════════════════════════════════════════════════╗
|
||||
║ ║
|
||||
║ 🌐 Dashboard & API Ready ║
|
||||
║ Dashboard & API Ready ║
|
||||
║ ║
|
||||
║ {dashboard_url}{" " * (69 - len(dashboard_url))}║
|
||||
║ ║
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
import socket
|
||||
from typing import Literal
|
||||
|
||||
import anyio
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from hypercorn import Config
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.constants import EXO_DEFAULT_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import CommandId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
class Tests(BaseModel):
|
||||
# list[hostname, ip addr]
|
||||
devs: list[list[str]]
|
||||
ibv_devs: list[list[str | None]] | None
|
||||
model_id: ModelId
|
||||
kind: Literal["ring", "jaccl", "both"]
|
||||
|
||||
|
||||
iid = InstanceId("im testing here")
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info("starting cool server majig")
|
||||
cfg = Config()
|
||||
cfg.bind = "0.0.0.0:52414"
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = "-"
|
||||
cfg.errorlog = "-"
|
||||
ev = anyio.Event()
|
||||
app = FastAPI()
|
||||
app.post("/run_test")(run_test)
|
||||
app.post("/kill")(lambda: kill(ev))
|
||||
app.get("/tb_detection")(tb_detection)
|
||||
app.get("/models")(list_models)
|
||||
await serve(
|
||||
app, # type: ignore
|
||||
cfg,
|
||||
shutdown_trigger=lambda: ev.wait(),
|
||||
)
|
||||
|
||||
|
||||
def kill(ev: anyio.Event):
|
||||
ev.set()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
async def tb_detection():
|
||||
send, recv = channel[GatheredInfo]()
|
||||
ig = InfoGatherer(send)
|
||||
with anyio.move_on_after(1):
|
||||
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
|
||||
with recv:
|
||||
return recv.collect()
|
||||
|
||||
|
||||
def list_models():
|
||||
sent = set[str]()
|
||||
for path in EXO_DEFAULT_MODELS_DIR.rglob("model-*.safetensors"):
|
||||
if "--" not in path.parent.name:
|
||||
continue
|
||||
name = path.parent.name.replace("--", "/")
|
||||
if name in sent:
|
||||
continue
|
||||
sent.add(name)
|
||||
yield ModelId(path.parent.name.replace("--", "/"))
|
||||
|
||||
|
||||
async def run_test(test: Tests):
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
hn = dev[0]
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
|
||||
async def run():
|
||||
logger.info(f"testing {test.model_id}")
|
||||
|
||||
instances: list[Instance] = []
|
||||
if test.kind in ["ring", "both"]:
|
||||
i = await ring_instance(test, hn)
|
||||
if i is None:
|
||||
yield "no model found"
|
||||
return
|
||||
instances.append(i)
|
||||
if test.kind in ["jaccl", "both"]:
|
||||
i = await jaccl_instance(test)
|
||||
if i is None:
|
||||
yield "no model found"
|
||||
return
|
||||
instances.append(i)
|
||||
|
||||
for instance in instances:
|
||||
recv = await execute_test(test, instance, hn)
|
||||
|
||||
str_out = ""
|
||||
|
||||
for item in recv:
|
||||
if isinstance(item, ChunkGenerated):
|
||||
assert isinstance(item.chunk, TokenChunk)
|
||||
str_out += item.chunk.text
|
||||
|
||||
if isinstance(item, RunnerStatusUpdated) and isinstance(
|
||||
item.runner_status, (RunnerFailed, RunnerShutdown)
|
||||
):
|
||||
yield str_out + "\n"
|
||||
yield item.model_dump_json() + "\n"
|
||||
|
||||
return StreamingResponse(run())
|
||||
|
||||
|
||||
async def ring_instance(test: Tests, hn: str) -> Instance | None:
|
||||
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
|
||||
world_size = len(test.devs)
|
||||
for i in range(world_size):
|
||||
if test.devs[i][0] == hn:
|
||||
hn = test.devs[i][0]
|
||||
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
|
||||
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
|
||||
hbn[i] = Host(ip="0.0.0.0", port=52417)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{hn} not in {test.devs}")
|
||||
|
||||
card = await ModelCard.load(test.model_id)
|
||||
instance = MlxRingInstance(
|
||||
instance_id=iid,
|
||||
ephemeral_port=52417,
|
||||
hosts_by_node={NodeId(hn): hbn},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=test.model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(test.devs[i][0]): PipelineShardMetadata(
|
||||
model_card=card,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=(card.n_layers // world_size) * i,
|
||||
end_layer=min(
|
||||
card.n_layers, (card.n_layers // world_size) * (i + 1)
|
||||
),
|
||||
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
|
||||
- (card.n_layers // world_size) * i,
|
||||
)
|
||||
for i in range(world_size)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
|
||||
world_size = len(test.devs)
|
||||
commands: list[Task] = [
|
||||
(LoadModel(instance_id=iid)),
|
||||
(StartWarmup(instance_id=iid)),
|
||||
(
|
||||
TextGeneration(
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=test.model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
input=[
|
||||
InputMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
)
|
||||
],
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
)
|
||||
),
|
||||
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
|
||||
]
|
||||
if world_size > 1:
|
||||
commands.insert(0, ConnectToGroup(instance_id=iid))
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
|
||||
)
|
||||
ev_send, _ev_recv = mp_channel[Event]()
|
||||
task_send, task_recv = mp_channel[Task]()
|
||||
|
||||
for command in commands:
|
||||
task_send.send(command)
|
||||
|
||||
entrypoint(
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
logger,
|
||||
)
|
||||
|
||||
# TODO(evan): return ev_recv.collect()
|
||||
return []
|
||||
|
||||
|
||||
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
|
||||
card = await ModelCard.load(test.model_id)
|
||||
world_size = len(test.devs)
|
||||
assert test.ibv_devs
|
||||
|
||||
return MlxJacclInstance(
|
||||
instance_id=iid,
|
||||
jaccl_devices=test.ibv_devs,
|
||||
# rank 0 is always coordinator
|
||||
jaccl_coordinators={
|
||||
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
|
||||
},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=test.model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(host[0]): TensorShardMetadata(
|
||||
model_card=card,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=0,
|
||||
end_layer=card.n_layers,
|
||||
n_layers=card.n_layers,
|
||||
)
|
||||
for i, host in enumerate(test.devs)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main)
|
||||
@@ -1,85 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import itertools
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, cast
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
if not (args := sys.argv[1:]):
|
||||
sys.exit(
|
||||
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be jaccl or ring"
|
||||
)
|
||||
|
||||
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
|
||||
hosts = args[1:] if kind != "both" else args
|
||||
ts = subprocess.run(
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
|
||||
ips = [ip[h] for h in hosts]
|
||||
devs = [[h, ip[h]] for h in hosts]
|
||||
n = len(hosts)
|
||||
|
||||
|
||||
def get_tb(a: str) -> list[dict[str, Any]]:
|
||||
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
|
||||
return json.loads(r.read()) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def get_models(a: str) -> set[str]:
    """Fetch ``/models`` from address *a* on port 52414 and return the decoded JSON as a set.

    The request uses a 5 second timeout; errors raised by ``urlopen`` or the
    JSON decoder propagate to the caller.
    """
    url = f"http://{a}:52414/models"
    with urlopen(url, timeout=5) as resp:  # pyright: ignore[reportAny]
        raw = resp.read()  # pyright: ignore[reportAny]
    decoded = json.loads(raw)  # pyright: ignore[reportAny]
    return set(decoded)  # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def run(h: str, a: str, body: bytes) -> None:
    """POST *body* as JSON to ``/run_test`` on address *a* and echo the reply.

    The whole response is read (300 second timeout), decoded with undecodable
    bytes replaced, and printed line by line, each line prefixed with
    ``<h>@<a>: `` so output from several hosts can be told apart.
    """
    request = Request(
        f"http://{a}:52414/run_test",
        data=body,
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    with urlopen(request, timeout=300) as resp:  # pyright: ignore[reportAny]
        text = resp.read().decode(errors="replace")  # pyright: ignore[reportAny]
        for line in text.splitlines():  # pyright: ignore[reportAny]
            # Flush per line so interleaved output from worker threads shows up promptly.
            print(f"\n{h}@{a}: {line}", flush=True)
|
||||
|
||||
|
||||
# Driver: one worker thread per host is used for every fan-out below.
with ThreadPoolExecutor(n) as pool:
    ibv_devs = None
    if kind in ("jaccl", "both"):
        payloads = list(pool.map(get_tb, ips))

        # Domain UUID -> (index of the host whose payload listed it, its rdmaInterface).
        # Later occurrences of the same UUID overwrite earlier ones.
        uuid_to_endpoint: dict[str, tuple[int, str]] = {}
        for host_idx, payload in enumerate(payloads):
            for entry in payload:
                idents = cast(
                    list[dict[str, str]],
                    entry.get("MacThunderboltIdentifiers", {}).get("idents", []),  # pyright: ignore[reportAny]
                )
                for ident in idents:
                    domain = ident["domainUuid"]
                    uuid_to_endpoint[domain] = (host_idx, ident["rdmaInterface"])

        # (source host index, sink host index) -> rdmaInterface recorded for the sink UUID.
        # Connections whose source or sink UUID is not a known domain are skipped.
        links: dict[tuple[int, int], str] = {}
        for payload in payloads:
            for entry in payload:
                conns = entry.get("MacThunderboltConnections", {}).get("conns", [])  # pyright: ignore[reportAny]
                for conn in conns:  # pyright: ignore[reportAny]
                    src = conn["sourceUuid"]  # pyright: ignore[reportAny]
                    if src not in uuid_to_endpoint:
                        continue
                    dst = conn["sinkUuid"]  # pyright: ignore[reportAny]
                    if dst not in uuid_to_endpoint:
                        continue
                    sink_idx, sink_iface = uuid_to_endpoint[dst]
                    links[(uuid_to_endpoint[src][0], sink_idx)] = sink_iface

        # n x n matrix indexed [source][sink]; None where no link was reported.
        ibv_devs = [[links.get((row, col)) for col in range(n)] for row in range(n)]

    # Only models that every host reports are tested.
    per_host_models = list(pool.map(get_models, ips))
    models = set[str].intersection(*per_host_models)

    rule = "=" * 70
    print("\n")
    print(rule)
    print(f"Starting test with {models}")
    print(rule)
    print("\n")

    for model in models:
        request_fields = {
            "devs": devs,
            "model_id": model,
            "ibv_devs": ibv_devs,
            "kind": kind,
        }
        body = json.dumps(request_fields).encode()
        # Send the same request to every host at once; list() drains the
        # iterator so all calls finish (and any exception surfaces) before
        # the next model is started.
        list(pool.map(run, hosts, ips, itertools.repeat(body)))
|
||||
@@ -42,7 +42,7 @@ i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run $remote_installable" 2>&1 |
|
||||
"EXO_ZENOH_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run $remote_installable" 2>&1 |
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user