mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 12:15:07 -04:00
big refactor
Fix. Everything. Co-authored-by: Andrei Cravtov <the.andrei.cravtov@gmail.com> Co-authored-by: Matt Beton <matthew.beton@gmail.com> Co-authored-by: Alex Cheema <alexcheema123@gmail.com> Co-authored-by: Seth Howes <sethshowes@gmail.com>
This commit is contained in:
@@ -1,71 +0,0 @@
|
||||
# Configures the Golang support and builds the forwarder
|
||||
# TODO: split this up in the future as this is unrelated tasks??
|
||||
|
||||
# Top-level parameters that are bound to the provider flake
|
||||
# These are passed from `flake.nix` using importApply
|
||||
{
|
||||
localSelf,
|
||||
flake-parts-lib,
|
||||
nixpkgs-lib,
|
||||
...
|
||||
}:
|
||||
|
||||
# These values would bind to the consumer flake when this flake module is imported:
|
||||
{
|
||||
config,
|
||||
self,
|
||||
inputs,
|
||||
getSystem,
|
||||
moduleWithSystem,
|
||||
withSystem,
|
||||
...
|
||||
}:
|
||||
|
||||
# The actual flake-parts module configuration
|
||||
{
|
||||
perSystem =
|
||||
{
|
||||
config,
|
||||
self',
|
||||
inputs',
|
||||
pkgs,
|
||||
system,
|
||||
...
|
||||
}:
|
||||
let
|
||||
# flakeRoot = nixpkgs-lib.getExe config.flake-root.package;
|
||||
#
|
||||
# # Build the networking/forwarder Go utility.
|
||||
# forwarder = pkgs.buildGoModule {
|
||||
# pname = "exo-forwarder";
|
||||
# version = "0.1.0";
|
||||
# src = "${flakeRoot}/networking/forwarder";
|
||||
#
|
||||
# vendorHash = "sha256-BXIGg2QYqHDz2TNe8hLAGC6jVlffp9766H+WdkkuVgA=";
|
||||
#
|
||||
# # Only the main package at the repository root needs building.
|
||||
# subPackages = [ "." ];
|
||||
# };
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
# inherit forwarder;
|
||||
};
|
||||
|
||||
apps = {
|
||||
# forwarder = {
|
||||
# type = "app";
|
||||
# program = "${forwarder}/bin/forwarder";
|
||||
# };
|
||||
};
|
||||
|
||||
make-shells.default = {
|
||||
# Go 1.24 compiler – align with go.mod
|
||||
packages = [ pkgs.go_1_24 ];
|
||||
shellHook = ''
|
||||
GOPATH="''$(${nixpkgs-lib.getExe config.flake-root.package})"/.go_cache
|
||||
export GOPATH
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
# Provides pretty banner & command index for this flake
|
||||
|
||||
# Top-level parameters that are bound to the provider flake
|
||||
# These are passed from `flake.nix` using importApply
|
||||
{
|
||||
localSelf,
|
||||
flake-parts-lib,
|
||||
nixpkgs-lib,
|
||||
just-flake,
|
||||
...
|
||||
}:
|
||||
|
||||
# These values would bind to the consumer flake when this flake module is imported:
|
||||
{
|
||||
config,
|
||||
self,
|
||||
inputs,
|
||||
getSystem,
|
||||
moduleWithSystem,
|
||||
withSystem,
|
||||
...
|
||||
}:
|
||||
|
||||
# The actual flake-parts module configuration
|
||||
{
|
||||
imports = [ just-flake.flakeModule ];
|
||||
perSystem =
|
||||
{
|
||||
config,
|
||||
self',
|
||||
inputs',
|
||||
pkgs,
|
||||
system,
|
||||
...
|
||||
}:
|
||||
{
|
||||
just-flake.features = {
|
||||
# treefmt.enable = true;
|
||||
# rust.enable = true;
|
||||
# convco.enable = true;
|
||||
# hello = {
|
||||
# enable = true;
|
||||
# justfile = ''
|
||||
# hello:
|
||||
# echo Hello World
|
||||
# '';
|
||||
# };
|
||||
};
|
||||
|
||||
make-shells.default = {
|
||||
inputsFrom = [ config.just-flake.outputs.devShell ];
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
# Provides macmon binary for the worker.
|
||||
|
||||
# These values would bind to the consumer flake when this flake module is imported:
|
||||
{
|
||||
config,
|
||||
self,
|
||||
inputs,
|
||||
getSystem,
|
||||
moduleWithSystem,
|
||||
withSystem,
|
||||
...
|
||||
}:
|
||||
|
||||
# The actual flake-parts module configuration
|
||||
{
|
||||
perSystem =
|
||||
{
|
||||
config,
|
||||
self',
|
||||
inputs',
|
||||
pkgs,
|
||||
system,
|
||||
...
|
||||
}:
|
||||
{
|
||||
make-shells.default = {
|
||||
packages = if (system == "aarch64-darwin") then ([ pkgs.macmon ]) else ([]);
|
||||
};
|
||||
};
|
||||
}
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -12,8 +12,6 @@ hosts*.json
|
||||
# TODO figure out how to properly solve the issue with these target directories showing up
|
||||
networking/target/
|
||||
networking/topology/target/
|
||||
rust/target/
|
||||
rust/Cargo.lock
|
||||
|
||||
build/
|
||||
dist/
|
||||
@@ -26,4 +24,4 @@ dist/
|
||||
just-flake.just
|
||||
|
||||
# for the gitingest enthusiasts
|
||||
digest.txt
|
||||
digest.txt
|
||||
|
||||
8
.idea/exo-v2.iml
generated
8
.idea/exo-v2.iml
generated
@@ -10,11 +10,19 @@
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/scripts/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/exo_pyo3_bindings/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/exo_pyo3_bindings/tests" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/util/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/networking/examples" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/networking/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/networking/tests" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/rust/system_custodian/src" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.direnv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/dist" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.go_cache" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/rust/target" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.13 (exo)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
|
||||
2
.idea/vcs.xml
generated
2
.idea/vcs.xml
generated
@@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
133
copy_model.sh
Executable file
133
copy_model.sh
Executable file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# copy_model.sh: clone ~/.exo/models from SOURCE to one or more TARGETS using scp -3.
|
||||
# Username defaults:
|
||||
# - If host is "aN" and no user given, username defaults to "aN".
|
||||
# - Otherwise defaults to $(whoami), unless you pass user@host.
|
||||
#
|
||||
# Examples:
|
||||
# ./copy_model.sh a1 a2 a3
|
||||
# ./copy_model.sh a1 frank@a2 192.168.1.3
|
||||
|
||||
if [ $# -lt 2 ]; then
|
||||
echo "Usage: $0 SOURCE TARGET [TARGET...]" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
SOURCE="$1"
|
||||
shift
|
||||
TARGETS=("$@")
|
||||
|
||||
DEFAULT_USER="$(whoami)"
|
||||
MODELS_REL=".exo/models" # relative under $HOME
|
||||
|
||||
timestamp() { date "+%Y-%m-%d %H:%M:%S"; }
|
||||
|
||||
split_user_host() {
|
||||
local in="$1"
|
||||
if [[ "$in" == *"@"* ]]; then
|
||||
printf "%s|%s" "${in%%@*}" "${in#*@}"
|
||||
else
|
||||
printf "|%s" "$in"
|
||||
fi
|
||||
}
|
||||
|
||||
resolve_ip() {
|
||||
local hostish="$1"
|
||||
if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
|
||||
echo "192.168.1.${BASH_REMATCH[1]}"
|
||||
else
|
||||
echo "$hostish"
|
||||
fi
|
||||
}
|
||||
|
||||
default_user_for() {
|
||||
local hostish="$1"
|
||||
if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
|
||||
echo "$hostish"
|
||||
else
|
||||
echo "$DEFAULT_USER"
|
||||
fi
|
||||
}
|
||||
|
||||
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR -o ConnectTimeout=10)
|
||||
SSHPASS_BIN="$(command -v sshpass || true)"
|
||||
SCP_BIN="${SCP_BIN:-scp}"
|
||||
|
||||
read -s -p "Password for all hosts: " PASS
|
||||
echo
|
||||
if [ -n "$SSHPASS_BIN" ]; then
|
||||
echo "$(timestamp) sshpass found: will provide the password non-interactively."
|
||||
else
|
||||
echo "$(timestamp) WARNING: sshpass not found — you’ll be prompted by scp/ssh per hop unless keys are set up."
|
||||
fi
|
||||
|
||||
# Build source endpoint (default username logic)
|
||||
IFS='|' read -r SRC_USER_RAW SRC_HOSTISH <<<"$(split_user_host "$SOURCE")"
|
||||
SRC_USER="${SRC_USER_RAW:-$(default_user_for "$SRC_HOSTISH")}"
|
||||
SRC_IP="$(resolve_ip "$SRC_HOSTISH")"
|
||||
SRC_HOST="${SRC_USER}@${SRC_IP}"
|
||||
|
||||
echo "$(timestamp) Source: ${SRC_HOST}:~/${MODELS_REL}"
|
||||
echo "$(timestamp) Targets: ${#TARGETS[@]}"
|
||||
|
||||
# Helper to run a simple remote command via ssh (for mkdir -p checks)
|
||||
ssh_run() {
|
||||
local host="$1"
|
||||
shift
|
||||
if [ -n "$SSHPASS_BIN" ]; then
|
||||
sshpass -p "$PASS" ssh "${SSH_OPTS[@]}" "$host" "$@"
|
||||
else
|
||||
ssh "${SSH_OPTS[@]}" "$host" "$@"
|
||||
fi
|
||||
}
|
||||
|
||||
# Ensure source dir exists (create if missing, per your request)
|
||||
ssh_run "$SRC_HOST" "mkdir -p ~/${MODELS_REL}"
|
||||
|
||||
failures=0
|
||||
count=0
|
||||
for T in "${TARGETS[@]}"; do
|
||||
count=$((count + 1))
|
||||
IFS='|' read -r T_USER_RAW T_HOSTISH <<<"$(split_user_host "$T")"
|
||||
T_USER="${T_USER_RAW:-$(default_user_for "$T_HOSTISH")}"
|
||||
T_IP="$(resolve_ip "$T_HOSTISH")"
|
||||
T_HOST="${T_USER}@${T_IP}"
|
||||
|
||||
echo "============================================================"
|
||||
echo "$(timestamp) [${count}/${#TARGETS[@]}] ${SRC_HOST} ==> ${T_HOST}"
|
||||
echo "$(timestamp) Ensuring destination directory exists…"
|
||||
ssh_run "$T_HOST" "mkdir -p ~/${MODELS_REL%/*}" # ~/.exo
|
||||
|
||||
# Copy the whole "models" directory into ~/.exo on the target.
|
||||
# scp -3 = copy between two remotes via local; -r recursive; -p preserve times/modes
|
||||
if [ -n "$SSHPASS_BIN" ]; then
|
||||
echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
|
||||
if sshpass -p "$PASS" "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
|
||||
"${SRC_HOST}:~/${MODELS_REL}" \
|
||||
"${T_HOST}:~/.exo/"; then
|
||||
echo "$(timestamp) [${count}] Done: ${T_HOST}"
|
||||
else
|
||||
echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
|
||||
failures=$((failures + 1))
|
||||
fi
|
||||
else
|
||||
echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
|
||||
if "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
|
||||
"${SRC_HOST}:~/${MODELS_REL}" \
|
||||
"${T_HOST}:~/.exo/"; then
|
||||
echo "$(timestamp) [${count}] Done: ${T_HOST}"
|
||||
else
|
||||
echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
|
||||
failures=$((failures + 1))
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
echo "============================================================"
|
||||
if [ "$failures" -eq 0 ]; then
|
||||
echo "$(timestamp) All transfers completed successfully."
|
||||
else
|
||||
echo "$(timestamp) Completed with ${failures} failure(s)."
|
||||
fi
|
||||
@@ -943,7 +943,7 @@
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
showLaunchStatus(`Instance launched successfully: ${result.instance_id}`, 'success');
|
||||
showLaunchStatus('Instance launched successfully');
|
||||
|
||||
// Reset form
|
||||
modelSelect.value = '';
|
||||
|
||||
39
flake.lock
generated
39
flake.lock
generated
@@ -1,5 +1,26 @@
|
||||
{
|
||||
"nodes": {
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1755585599,
|
||||
"narHash": "sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
@@ -102,12 +123,30 @@
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"fenix": "fenix",
|
||||
"flake-parts": "flake-parts",
|
||||
"flake-root": "flake-root",
|
||||
"just-flake": "just-flake",
|
||||
"make-shell": "make-shell",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1755504847,
|
||||
"narHash": "sha256-VX0B9hwhJypCGqncVVLC+SmeMVd/GAYbJZ0MiiUn2Pk=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "a905e3b21b144d77e1b304e49f3264f6f8d4db75",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "rust-lang",
|
||||
"ref": "nightly",
|
||||
"repo": "rust-analyzer",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
92
flake.nix
92
flake.nix
@@ -20,47 +20,39 @@
|
||||
|
||||
# Provides flake integration with [Just](https://just.systems/man/en/)
|
||||
just-flake.url = "github:juspay/just-flake";
|
||||
|
||||
# Provides Rust dev-env integration:
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
# TODO: figure out caching story
|
||||
# nixConfig = {
|
||||
# # nix community cachix
|
||||
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
|
||||
# extra-substituters = "https://nix-community.cachix.org";
|
||||
# };
|
||||
|
||||
outputs =
|
||||
inputs@{
|
||||
flake-parts,
|
||||
...
|
||||
}:
|
||||
flake-parts.lib.mkFlake { inherit inputs; } (
|
||||
{
|
||||
flake-parts-lib,
|
||||
self,
|
||||
...
|
||||
}:
|
||||
let
|
||||
nixpkgs-lib = inputs.nixpkgs.lib;
|
||||
|
||||
# A wraper around importApply that supplies default parameters
|
||||
importApply' =
|
||||
path: extraParams:
|
||||
(flake-parts-lib.importApply path (
|
||||
nixpkgs-lib.recursiveUpdate {
|
||||
localSelf = self;
|
||||
inherit flake-parts-lib;
|
||||
inherit nixpkgs-lib;
|
||||
} extraParams
|
||||
));
|
||||
|
||||
# instantiate all the flake modules, passing custom arguments to them as needed
|
||||
flakeModules = {
|
||||
flakeRoot = importApply' ./.flake-modules/flake-root.nix { inherit (inputs) flake-root; };
|
||||
justFlake = importApply' ./.flake-modules/just-flake.nix { inherit (inputs) just-flake; };
|
||||
goForwarder = importApply' ./.flake-modules/go-forwarder.nix { };
|
||||
};
|
||||
in
|
||||
{ flake-parts-lib, self, ... }:
|
||||
{
|
||||
imports = [
|
||||
inputs.make-shell.flakeModules.default
|
||||
flakeModules.flakeRoot
|
||||
flakeModules.justFlake
|
||||
flakeModules.goForwarder
|
||||
./.flake-modules/macmon.nix
|
||||
|
||||
./nix/modules/pkgs-init.nix # nixpkgs overlays manager
|
||||
./nix/modules/flake-root.nix
|
||||
./nix/modules/just-flake.nix
|
||||
./nix/modules/macmon.nix
|
||||
./nix/modules/python.nix
|
||||
./nix/modules/rust.nix
|
||||
./nix/modules/go-forwarder.nix
|
||||
];
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
@@ -75,55 +67,31 @@
|
||||
system,
|
||||
...
|
||||
}:
|
||||
let
|
||||
buildInputs = with pkgs; [
|
||||
];
|
||||
nativeBuildInputs = with pkgs; [
|
||||
];
|
||||
in
|
||||
{
|
||||
# Per-system attributes can be defined here. The self' and inputs'
|
||||
# module parameters provide easy access to attributes of the same
|
||||
# system.
|
||||
# NOTE: pkgs is equivalent to inputs'.nixpkgs.legacyPackages.hello;
|
||||
apps = {
|
||||
python-lsp = {
|
||||
type = "app";
|
||||
program = "${pkgs.basedpyright}/bin/basedpyright-langserver";
|
||||
};
|
||||
default = self'.apps.forwarder;
|
||||
};
|
||||
apps = { };
|
||||
|
||||
make-shells.default = {
|
||||
packages = [
|
||||
pkgs.python313
|
||||
pkgs.uv
|
||||
pkgs.protobuf
|
||||
pkgs.basedpyright
|
||||
pkgs.ruff
|
||||
];
|
||||
|
||||
nativeBuildInputs =
|
||||
with pkgs;
|
||||
[
|
||||
nixpkgs-fmt
|
||||
cmake
|
||||
]
|
||||
++ buildInputs
|
||||
++ nativeBuildInputs;
|
||||
|
||||
# Arguments which are intended to be environment variables in the shell environment
|
||||
# should be changed to attributes of the `env` option
|
||||
env = {
|
||||
# fixes libstdc++.so issues and libgl.so issues
|
||||
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
|
||||
};
|
||||
nativeBuildInputs = with pkgs; [
|
||||
nixpkgs-fmt
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
export GO_BUILD_DIR=$(git rev-parse --show-toplevel)/build;
|
||||
export DASHBOARD_DIR=$(git rev-parse --show-toplevel)/dashboard;
|
||||
'';
|
||||
|
||||
# Arguments which are intended to be environment variables in the shell environment
|
||||
# should be changed to attributes of the `env` option
|
||||
env = { };
|
||||
|
||||
# Arbitrary mkDerivation arguments should be changed to be attributes of the `additionalArguments` option
|
||||
additionalArguments = { };
|
||||
};
|
||||
|
||||
65
kill_remote.sh
Executable file
65
kill_remote.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
###############################################################################
|
||||
# Args & prerequisites
|
||||
###############################################################################
|
||||
if [[ $# -gt 1 ]]; then
|
||||
echo "Usage: $0 [hosts_file]" >&2
|
||||
exit 1
|
||||
fi
|
||||
HOSTS_FILE=${1:-hosts.txt}
|
||||
|
||||
###############################################################################
|
||||
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
|
||||
###############################################################################
|
||||
if [[ ! -f "$HOSTS_FILE" ]]; then
|
||||
echo "Error: $HOSTS_FILE not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if builtin command -v mapfile >/dev/null 2>&1; then
|
||||
mapfile -t HOSTS <"$HOSTS_FILE"
|
||||
else
|
||||
HOSTS=()
|
||||
while IFS= read -r h; do
|
||||
[[ -n "$h" ]] && HOSTS+=("$h")
|
||||
done <"$HOSTS_FILE"
|
||||
fi
|
||||
[[ ${#HOSTS[@]} -gt 0 ]] || {
|
||||
echo "No hosts found in $HOSTS_FILE"
|
||||
exit 1
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# Helper – run a remote command and capture rc/stderr/stdout
|
||||
###############################################################################
|
||||
ssh_opts=(-o StrictHostKeyChecking=no
|
||||
-o LogLevel=ERROR)
|
||||
|
||||
run_remote() { # $1 host $2 command
|
||||
local host=$1 cmd=$2 rc
|
||||
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
|
||||
rc=0
|
||||
else
|
||||
rc=$?
|
||||
fi
|
||||
return $rc
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# Kill exo everywhere (parallel)
|
||||
###############################################################################
|
||||
echo "=== Killing exo on ${#HOSTS[@]} host(s) ==="
|
||||
fail=0
|
||||
for h in "${HOSTS[@]}"; do
|
||||
(
|
||||
run_remote "$h" 'pkill -f exo || true'
|
||||
) || fail=1 &
|
||||
done
|
||||
wait
|
||||
((fail == 0)) || {
|
||||
echo "❌ Some hosts could not be reached—check SSH access."
|
||||
exit 1
|
||||
}
|
||||
echo "✓ exo processes killed on all reachable hosts."
|
||||
@@ -2,39 +2,14 @@
|
||||
# 1. ${lib.getExe config.flake-root.package}
|
||||
# 2. $FLAKE_ROOT environment-varible
|
||||
|
||||
# Top-level parameters that are bound to the provider flake
|
||||
# These are passed from `flake.nix` using importApply
|
||||
{
|
||||
localSelf,
|
||||
flake-parts-lib,
|
||||
nixpkgs-lib,
|
||||
flake-root,
|
||||
...
|
||||
}:
|
||||
|
||||
# These values would bind to the consumer flake when this flake module is imported:
|
||||
{
|
||||
config,
|
||||
self,
|
||||
inputs,
|
||||
getSystem,
|
||||
moduleWithSystem,
|
||||
withSystem,
|
||||
...
|
||||
}:
|
||||
{ inputs, ... }:
|
||||
|
||||
# The actual flake-parts module configuration
|
||||
{
|
||||
imports = [ flake-root.flakeModule ];
|
||||
imports = [ inputs.flake-root.flakeModule ];
|
||||
perSystem =
|
||||
{
|
||||
config,
|
||||
self',
|
||||
inputs',
|
||||
pkgs,
|
||||
system,
|
||||
...
|
||||
}:
|
||||
{ config, ... }:
|
||||
{
|
||||
flake-root.projectRootFile = "flake.nix"; # Not necessary, as flake.nix is the default
|
||||
|
||||
19
nix/modules/go-forwarder.nix
Normal file
19
nix/modules/go-forwarder.nix
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
perSystem =
|
||||
{
|
||||
config,
|
||||
pkgs,
|
||||
lib,
|
||||
...
|
||||
}:
|
||||
{
|
||||
make-shells.default = {
|
||||
# Go 1.24 compiler – align with go.mod
|
||||
packages = [ pkgs.go_1_24 ];
|
||||
shellHook = ''
|
||||
GOPATH="''$(${lib.getExe config.flake-root.package})"/.go_cache
|
||||
export GOPATH
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
26
nix/modules/just-flake.nix
Normal file
26
nix/modules/just-flake.nix
Normal file
@@ -0,0 +1,26 @@
|
||||
# Provides pretty banner & command index for this flake
|
||||
|
||||
{ inputs, ... }:
|
||||
{
|
||||
imports = [ inputs.just-flake.flakeModule ];
|
||||
perSystem =
|
||||
{ config, ... }:
|
||||
{
|
||||
just-flake.features = {
|
||||
# treefmt.enable = true;
|
||||
# rust.enable = true;
|
||||
# convco.enable = true;
|
||||
# hello = {
|
||||
# enable = true;
|
||||
# justfile = ''
|
||||
# hello:
|
||||
# echo Hello World
|
||||
# '';
|
||||
# };
|
||||
};
|
||||
|
||||
make-shells.default = {
|
||||
inputsFrom = [ config.just-flake.outputs.devShell ];
|
||||
};
|
||||
};
|
||||
}
|
||||
12
nix/modules/macmon.nix
Normal file
12
nix/modules/macmon.nix
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
perSystem =
|
||||
{ lib, pkgs, ... }:
|
||||
lib.mkMerge [
|
||||
(lib.mkIf pkgs.stdenv.isDarwin {
|
||||
make-shells.default = {
|
||||
packages = [ pkgs.macmon ];
|
||||
};
|
||||
})
|
||||
];
|
||||
|
||||
}
|
||||
62
nix/modules/pkgs-init.nix
Normal file
62
nix/modules/pkgs-init.nix
Normal file
@@ -0,0 +1,62 @@
|
||||
# Single module responsible for collecting all overlays and instantiating in one go
|
||||
|
||||
{
|
||||
flake-parts-lib,
|
||||
inputs,
|
||||
self,
|
||||
specialArgs,
|
||||
...
|
||||
}:
|
||||
let
|
||||
inherit (flake-parts-lib) mkPerSystemOption;
|
||||
in
|
||||
{
|
||||
options.perSystem = mkPerSystemOption (
|
||||
{
|
||||
system,
|
||||
config,
|
||||
lib,
|
||||
options,
|
||||
pkgs,
|
||||
self',
|
||||
...
|
||||
}@args:
|
||||
let
|
||||
inherit (lib.types)
|
||||
attrsOf
|
||||
listOf
|
||||
submoduleWith
|
||||
raw
|
||||
;
|
||||
in
|
||||
{
|
||||
options.pkgs-init.overlays = lib.mkOption {
|
||||
description = ''
|
||||
List of nixpkgs overlays (functions of the form: final: prev: { ... }).
|
||||
Any module can append. Order matters.
|
||||
'';
|
||||
default = [ ];
|
||||
example = [
|
||||
(final: prev: {
|
||||
my-hello = prev.hello;
|
||||
})
|
||||
];
|
||||
type = lib.types.listOf lib.types.unspecified;
|
||||
};
|
||||
options.pkgs-init.importArgs = lib.mkOption {
|
||||
description = "Extra arguments merged into the nixpkgs import call.";
|
||||
default = { };
|
||||
type = lib.types.attrs;
|
||||
};
|
||||
config = {
|
||||
_module.args.pkgs = import inputs.nixpkgs (
|
||||
{
|
||||
inherit system;
|
||||
overlays = config.pkgs-init.overlays;
|
||||
}
|
||||
// config.pkgs-init.importArgs
|
||||
);
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
20
nix/modules/python.nix
Normal file
20
nix/modules/python.nix
Normal file
@@ -0,0 +1,20 @@
|
||||
# Configures Python shell
|
||||
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, ... }:
|
||||
{
|
||||
make-shells.default = {
|
||||
packages = [
|
||||
pkgs.python313
|
||||
pkgs.uv
|
||||
pkgs.ruff
|
||||
pkgs.basedpyright
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${pkgs.python313}/lib
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
25
nix/modules/rust.nix
Normal file
25
nix/modules/rust.nix
Normal file
@@ -0,0 +1,25 @@
|
||||
# Configures Rust shell
|
||||
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, ... }:
|
||||
{
|
||||
pkgs-init.overlays = [
|
||||
inputs.fenix.overlays.default
|
||||
];
|
||||
|
||||
make-shells.default = {
|
||||
packages = [
|
||||
(pkgs.fenix.complete.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
pkgs.rustup # literally only added to make RustRover happy (otherwise useless)
|
||||
];
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -33,6 +33,9 @@ dependencies = [
|
||||
"cobs>=1.2.2",
|
||||
"loguru>=0.7.3",
|
||||
"textual>=5.3.0",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio>=4.10.0",
|
||||
"bidict>=0.23.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -61,8 +64,12 @@ darwin = [
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"scripts",
|
||||
"rust/exo_pyo3_bindings",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.9,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -87,7 +94,7 @@ reportUnnecessaryTypeIgnoreComment = "error"
|
||||
pythonVersion = "3.13"
|
||||
pythonPlatform = "Darwin"
|
||||
|
||||
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts"]
|
||||
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
|
||||
stubPath = "typings"
|
||||
|
||||
[[tool.basedpyright.executionEnvironments]]
|
||||
|
||||
@@ -4,47 +4,49 @@ set -euo pipefail
|
||||
###############################################################################
|
||||
# Args & prerequisites
|
||||
###############################################################################
|
||||
if [[ $# -lt 2 ]]; then
|
||||
echo "Usage: $0 <PASSWORD> <git_command> [git_args...]" >&2
|
||||
if [[ $# -lt 1 ]]; then
|
||||
echo "Usage: $0 <git_command> [git_args...]" >&2
|
||||
echo "Examples:" >&2
|
||||
echo " $0 mypassword pull" >&2
|
||||
echo " $0 mypassword checkout main" >&2
|
||||
echo " $0 mypassword status" >&2
|
||||
echo " $0 mypassword fetch --all" >&2
|
||||
echo " $0 pull" >&2
|
||||
echo " $0 checkout main" >&2
|
||||
echo " $0 status" >&2
|
||||
echo " $0 fetch --all" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PASSWORD=$1
|
||||
shift # Remove password from args
|
||||
GIT_CMD="$*" # Remaining args form the git command
|
||||
HOSTS_FILE=${HOSTS_FILE:-hosts.json}
|
||||
|
||||
for prog in jq sshpass; do
|
||||
command -v "$prog" >/dev/null ||
|
||||
{ echo "Error: $prog not installed."; exit 1; }
|
||||
done
|
||||
GIT_CMD="$*" # All args form the git command
|
||||
HOSTS_FILE=${HOSTS_FILE:-hosts.txt}
|
||||
|
||||
###############################################################################
|
||||
# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
|
||||
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
|
||||
###############################################################################
|
||||
if [[ ! -f "$HOSTS_FILE" ]]; then
|
||||
echo "Error: $HOSTS_FILE not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if builtin command -v mapfile >/dev/null 2>&1; then
|
||||
mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
|
||||
mapfile -t HOSTS <"$HOSTS_FILE"
|
||||
else
|
||||
HOSTS=()
|
||||
while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
|
||||
while IFS= read -r h; do
|
||||
[[ -n "$h" ]] && HOSTS+=("$h")
|
||||
done <"$HOSTS_FILE"
|
||||
fi
|
||||
[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
|
||||
[[ ${#HOSTS[@]} -gt 0 ]] || {
|
||||
echo "No hosts found in $HOSTS_FILE"
|
||||
exit 1
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# Helper – run a remote command and capture rc/stderr/stdout
|
||||
###############################################################################
|
||||
ssh_opts=(-o StrictHostKeyChecking=no
|
||||
-o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
|
||||
-o LogLevel=ERROR)
|
||||
-o LogLevel=ERROR)
|
||||
|
||||
run_remote () { # $1 host $2 command
|
||||
run_remote() { # $1 host $2 command
|
||||
local host=$1 cmd=$2 rc
|
||||
if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
|
||||
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
|
||||
rc=0
|
||||
else
|
||||
rc=$?
|
||||
@@ -72,9 +74,9 @@ done
|
||||
wait
|
||||
|
||||
echo ""
|
||||
if (( fail == 0 )); then
|
||||
if ((fail == 0)); then
|
||||
echo "🎉 Git command executed successfully on all hosts!"
|
||||
else
|
||||
echo "⚠️ Some hosts failed—see above."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -4,38 +4,42 @@ set -euo pipefail
|
||||
###############################################################################
|
||||
# Args & prerequisites
|
||||
###############################################################################
|
||||
if [[ $# -lt 1 || $# -gt 2 ]]; then
|
||||
echo "Usage: $0 <PASSWORD> [hosts_file]" >&2 ; exit 1
|
||||
if [[ $# -gt 1 ]]; then
|
||||
echo "Usage: $0 [hosts_file]" >&2
|
||||
exit 1
|
||||
fi
|
||||
PASSWORD=$1
|
||||
HOSTS_FILE=${2:-hosts.json}
|
||||
|
||||
for prog in jq sshpass; do
|
||||
command -v "$prog" >/dev/null ||
|
||||
{ echo "Error: $prog not installed."; exit 1; }
|
||||
done
|
||||
HOSTS_FILE=${1:-hosts.txt}
|
||||
|
||||
###############################################################################
|
||||
# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
|
||||
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
|
||||
###############################################################################
|
||||
if [[ ! -f "$HOSTS_FILE" ]]; then
|
||||
echo "Error: $HOSTS_FILE not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if builtin command -v mapfile >/dev/null 2>&1; then
|
||||
mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
|
||||
mapfile -t HOSTS <"$HOSTS_FILE"
|
||||
else
|
||||
HOSTS=()
|
||||
while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
|
||||
while IFS= read -r h; do
|
||||
[[ -n "$h" ]] && HOSTS+=("$h")
|
||||
done <"$HOSTS_FILE"
|
||||
fi
|
||||
[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
|
||||
[[ ${#HOSTS[@]} -gt 0 ]] || {
|
||||
echo "No hosts found in $HOSTS_FILE"
|
||||
exit 1
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# Helper – run a remote command and capture rc/stderr/stdout
|
||||
###############################################################################
|
||||
ssh_opts=(-o StrictHostKeyChecking=no
|
||||
-o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
|
||||
-o LogLevel=ERROR)
|
||||
-o LogLevel=ERROR)
|
||||
|
||||
run_remote () { # $1 host $2 command
|
||||
run_remote() { # $1 host $2 command
|
||||
local host=$1 cmd=$2 rc
|
||||
if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
|
||||
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
|
||||
rc=0
|
||||
else
|
||||
rc=$?
|
||||
@@ -54,26 +58,42 @@ for h in "${HOSTS[@]}"; do
|
||||
) || fail=1 &
|
||||
done
|
||||
wait
|
||||
(( fail == 0 )) || { echo "❌ Some hosts could not be reached—check password or SSH access."; exit 1; }
|
||||
((fail == 0)) || {
|
||||
echo "❌ Some hosts could not be reached—check SSH access."
|
||||
exit 1
|
||||
}
|
||||
echo "✓ exo processes killed on all reachable hosts."
|
||||
|
||||
#
|
||||
###############################################################################
|
||||
# Phase 2 – start new exo processes (parallel, with sudo -S)
|
||||
# Phase 2 – cleanup database files (parallel)
|
||||
###############################################################################
|
||||
echo "=== Stage 2: starting new exo processes ==="
|
||||
echo "=== Stage 2: cleaning up database files ==="
|
||||
fail=0
|
||||
for i in "${!HOSTS[@]}"; do
|
||||
h=${HOSTS[$i]}
|
||||
|
||||
# one liner that pre-caches sudo and then runs the script
|
||||
if [[ $i -eq 0 ]]; then
|
||||
remote_cmd="cd ~/exo && ./run.sh -c"
|
||||
else
|
||||
remote_cmd="cd ~/exo && ./run.sh -rc"
|
||||
fi
|
||||
|
||||
( run_remote "$h" "$remote_cmd" ) || fail=1 &
|
||||
for h in "${HOSTS[@]}"; do
|
||||
(
|
||||
run_remote "$h" 'rm -f ~/.exo/*db* || true'
|
||||
) || fail=1 &
|
||||
done
|
||||
wait
|
||||
(( fail == 0 )) && echo "🎉 Deployment finished!" || \
|
||||
{ echo "⚠️ Some starts failed—see above."; exit 1; }
|
||||
((fail == 0)) || {
|
||||
echo "❌ Some hosts failed database cleanup."
|
||||
exit 1
|
||||
}
|
||||
echo "✓ Database files cleaned on all hosts."
|
||||
|
||||
###############################################################################
|
||||
# Phase 3 – start new exo processes in Terminal windows (parallel)
|
||||
###############################################################################
|
||||
echo "=== Stage 3: starting new exo processes ==="
|
||||
fail=0
|
||||
for h in "${HOSTS[@]}"; do
|
||||
# Use osascript to open Terminal windows on remote Mac
|
||||
remote_cmd="osascript -e \"tell app \\\"Terminal\\\" to do script \\\"cd ~/exo; nix develop --command uv run exo\\\"\""
|
||||
|
||||
(run_remote "$h" "$remote_cmd") || fail=1 &
|
||||
done
|
||||
wait
|
||||
((fail == 0)) && echo "🎉 Deployment finished!" || {
|
||||
echo "⚠️ Some starts failed—see above."
|
||||
exit 1
|
||||
}
|
||||
|
||||
15
rust/.gitignore
vendored
Normal file
15
rust/.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug
|
||||
target
|
||||
Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
# Generated by cargo mutants
|
||||
# Contains mutation testing data
|
||||
**/mutants.out*/
|
||||
165
rust/Cargo.toml
Normal file
165
rust/Cargo.toml
Normal file
@@ -0,0 +1,165 @@
|
||||
[workspace]
|
||||
resolver = "3"
|
||||
members = [
|
||||
"networking",
|
||||
"exo_pyo3_bindings",
|
||||
"system_custodian",
|
||||
"util",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.0.1"
|
||||
edition = "2024"
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 1
|
||||
debug = true
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
|
||||
# Common shared dependendencies configured once at the workspace
|
||||
# level, to be re-used more easily across workspace member crates.
|
||||
#
|
||||
# Common configurations include versions, paths, features, etc.
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "networking" }
|
||||
system_custodian = { path = "system_custodian" }
|
||||
util = { path = "util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
# networking
|
||||
libp2p = "0.56"
|
||||
libp2p-tcp = "0.44"
|
||||
|
||||
[workspace.lints.rust]
|
||||
static_mut_refs = "warn" # Or use "warn" instead of deny
|
||||
incomplete_features = "allow"
|
||||
|
||||
# Clippy's lint category level configurations;
|
||||
# every member crate needs to inherit these by adding
|
||||
#
|
||||
# ```toml
|
||||
# [lints]
|
||||
# workspace = true
|
||||
# ```
|
||||
#
|
||||
# to their `Cargo.toml` files
|
||||
[workspace.lints.clippy]
|
||||
# Clippy lint categories meant to be enabled all at once
|
||||
correctness = { level = "deny", priority = -1 }
|
||||
suspicious = { level = "warn", priority = -1 }
|
||||
style = { level = "warn", priority = -1 }
|
||||
complexity = { level = "warn", priority = -1 }
|
||||
perf = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
nursery = { level = "warn", priority = -1 }
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
|
||||
# Individual Clippy lints from the `restriction` category
|
||||
arithmetic_side_effects = "warn"
|
||||
as_conversions = "warn"
|
||||
assertions_on_result_states = "warn"
|
||||
clone_on_ref_ptr = "warn"
|
||||
decimal_literal_representation = "warn"
|
||||
default_union_representation = "warn"
|
||||
deref_by_slicing = "warn"
|
||||
disallowed_script_idents = "deny"
|
||||
else_if_without_else = "warn"
|
||||
empty_enum_variants_with_brackets = "warn"
|
||||
empty_structs_with_brackets = "warn"
|
||||
error_impl_error = "warn"
|
||||
exit = "deny"
|
||||
expect_used = "warn"
|
||||
float_cmp_const = "warn"
|
||||
get_unwrap = "warn"
|
||||
if_then_some_else_none = "warn"
|
||||
impl_trait_in_params = "warn"
|
||||
indexing_slicing = "warn"
|
||||
infinite_loop = "warn"
|
||||
let_underscore_must_use = "warn"
|
||||
let_underscore_untyped = "warn"
|
||||
lossy_float_literal = "warn"
|
||||
mem_forget = "warn"
|
||||
missing_inline_in_public_items = "warn"
|
||||
multiple_inherent_impl = "warn"
|
||||
multiple_unsafe_ops_per_block = "warn"
|
||||
mutex_atomic = "warn"
|
||||
non_zero_suggestions = "warn"
|
||||
panic = "warn"
|
||||
partial_pub_fields = "warn"
|
||||
pattern_type_mismatch = "warn"
|
||||
pub_without_shorthand = "warn"
|
||||
rc_buffer = "warn"
|
||||
rc_mutex = "warn"
|
||||
redundant_type_annotations = "warn"
|
||||
renamed_function_params = "warn"
|
||||
rest_pat_in_fully_bound_structs = "warn"
|
||||
same_name_method = "warn"
|
||||
self_named_module_files = "deny"
|
||||
semicolon_inside_block = "warn"
|
||||
shadow_same = "warn"
|
||||
shadow_unrelated = "warn"
|
||||
str_to_string = "warn"
|
||||
string_add = "warn"
|
||||
string_lit_chars_any = "warn"
|
||||
string_to_string = "warn"
|
||||
tests_outside_test_module = "warn"
|
||||
todo = "warn"
|
||||
try_err = "warn"
|
||||
undocumented_unsafe_blocks = "warn"
|
||||
unnecessary_safety_comment = "warn"
|
||||
unnecessary_safety_doc = "warn"
|
||||
unneeded_field_pattern = "warn"
|
||||
unseparated_literal_suffix = "warn"
|
||||
unused_result_ok = "warn"
|
||||
unused_trait_names = "warn"
|
||||
unwrap_used = "warn"
|
||||
verbose_file_reads = "warn"
|
||||
2
rust/clippy.toml
Normal file
2
rust/clippy.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
77
rust/exo_pyo3_bindings/Cargo.toml
Normal file
77
rust/exo_pyo3_bindings/Cargo.toml
Normal file
@@ -0,0 +1,77 @@
|
||||
[package]
|
||||
name = "exo_pyo3_bindings"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
path = "src/lib.rs"
|
||||
name = "exo_pyo3_bindings"
|
||||
|
||||
# "cdylib" needed to produce shared library for Python to import
|
||||
# "rlib" needed for stub-gen to run
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/stub_gen.rs"
|
||||
name = "stub_gen"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.13.1" }
|
||||
pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
#tracing-subscriber = "0.3"
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
pyo3-log = "0.12"
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
1
rust/exo_pyo3_bindings/README.md
Normal file
1
rust/exo_pyo3_bindings/README.md
Normal file
@@ -0,0 +1 @@
|
||||
TODO: do something here....
|
||||
207
rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi
Normal file
207
rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi
Normal file
@@ -0,0 +1,207 @@
|
||||
# This file is automatically generated by pyo3_stub_gen
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
from enum import Enum
|
||||
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes:bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes:bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes:bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes:bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n:builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes:bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string:builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
"""
|
||||
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity:Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit:builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic:builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
|
||||
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
"""
|
||||
async def gossipsub_unsubscribe(self, topic:builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Unsubscribes from a `GossipSub` topic.
|
||||
|
||||
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
"""
|
||||
async def gossipsub_publish(self, topic:builtins.str, data:bytes) -> None:
|
||||
r"""
|
||||
Publishes a message with multiple topics to the `GossipSub` network.
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit:builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __new__(cls, *args) -> NoPeersSubscribedToTopicError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes:bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
class ConnectionUpdateType(Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
|
||||
32
rust/exo_pyo3_bindings/pyproject.toml
Normal file
32
rust/exo_pyo3_bindings/pyproject.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"exo_pyo3_bindings",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.maturin]
|
||||
#purelib = true
|
||||
#python-source = "python"
|
||||
module-name = "exo_pyo3_bindings"
|
||||
features = ["pyo3/extension-module", "pyo3/experimental-async"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
log_cli = true
|
||||
log_cli_level = "INFO"
|
||||
asyncio_mode = "auto"
|
||||
40
rust/exo_pyo3_bindings/src/allow_threading.rs
Normal file
40
rust/exo_pyo3_bindings/src/allow_threading.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::{Pin, pin},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
#[pin_project]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct AllowThreads<F>(#[pin] F);
|
||||
|
||||
impl<F> AllowThreads<F>
|
||||
where
|
||||
Self: Future,
|
||||
{
|
||||
pub fn new(f: F) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
8
rust/exo_pyo3_bindings/src/bin/stub_gen.rs
Normal file
8
rust/exo_pyo3_bindings/src/bin/stub_gen.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
use pyo3_stub_gen::Result;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
|
||||
let stub = exo_pyo3_bindings::stub_info()?;
|
||||
stub.generate()?;
|
||||
Ok(())
|
||||
}
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
217
rust/exo_pyo3_bindings/src/lib.rs
Normal file
217
rust/exo_pyo3_bindings/src/lib.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = ResultExt)]
|
||||
impl<T, E> Result<T, E>
|
||||
where
|
||||
E: ToString,
|
||||
{
|
||||
fn pyerr(self) -> PyResult<T> {
|
||||
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
{
|
||||
AllowThreads::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Future> FutureExt for T {}
|
||||
|
||||
#[ext(pub, name = PyErrExt)]
|
||||
impl PyErr {
|
||||
fn receiver_channel_closed() -> Self {
|
||||
PyConnectionError::new_err("Receiver channel closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
match self {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
// write error back to python
|
||||
e.write_unraisable(py, None);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule(name = "exo_pyo3_bindings")]
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
define_stub_info_gatherer!(stub_info);
|
||||
534
rust/exo_pyo3_bindings/src/networking.rs
Normal file
534
rust/exo_pyo3_bindings/src/networking.rs
Normal file
@@ -0,0 +1,534 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use util::ext::VecExt as _;
|
||||
|
||||
mod exception {
|
||||
use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments};
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3_stub_gen::{derive::*};
|
||||
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
|
||||
pub struct PyNoPeersSubscribedToTopicError {}
|
||||
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
const MSG: &'static str = "\
|
||||
No peers are currently subscribed to receive messages on this topic. \
|
||||
Wait for peers to subscribe or check your network connectivity.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNetworkingHandle {
|
||||
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
|
||||
// immediately beforehand to release the interpreter.
|
||||
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
|
||||
// ---- Lifecycle management methods ----
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & convert any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
|
||||
/// Publishes a message with multiple topics to the `GossipSub` network.
|
||||
///
|
||||
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors => ignore messageID for now!!!
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.map(|(t, d)| (t, d.pybytes())))
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use core::mem::drop;
|
||||
use core::option::Option::Some;
|
||||
use core::time::Duration;
|
||||
use tokio;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_channel() {
|
||||
struct Ping;
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Ping>(10);
|
||||
|
||||
let _ = tokio::spawn(async move {
|
||||
println!("TASK: entered");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.recv() => {
|
||||
match result {
|
||||
Some(_) => {
|
||||
println!("TASK: pinged");
|
||||
}
|
||||
None => {
|
||||
println!("TASK: closing channel");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
|
||||
println!("TASK: heartbeat");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("TASK: exited");
|
||||
});
|
||||
|
||||
let tx2 = tx.clone();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx.send(Ping).await.expect("Should not fail");
|
||||
drop(tx);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx2.send(Ping).await.expect("Should not fail");
|
||||
drop(tx2);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
}
|
||||
}
|
||||
34
rust/exo_pyo3_bindings/tests/test_python.py
Normal file
34
rust/exo_pyo3_bindings/tests/test_python.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep_on_multiple_items() -> None:
|
||||
print("PYTHON: starting handle")
|
||||
h = NetworkingHandle(Keypair.generate_ed25519())
|
||||
|
||||
ct = asyncio.create_task(_await_cons(h))
|
||||
mt = asyncio.create_task(_await_msg(h))
|
||||
|
||||
# sleep for 4 ticks
|
||||
for i in range(4):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
await h.gossipsub_publish("topic", b"somehting or other")
|
||||
except NoPeersSubscribedToTopicError as e:
|
||||
print("caught it", e)
|
||||
|
||||
|
||||
async def _await_cons(h: NetworkingHandle):
|
||||
while True:
|
||||
c = await h.connection_update_recv()
|
||||
print(f"PYTHON: connection update: {c}")
|
||||
|
||||
|
||||
async def _await_msg(h: NetworkingHandle):
|
||||
while True:
|
||||
m = await h.gossipsub_recv()
|
||||
print(f"PYTHON: message: {m}")
|
||||
44
rust/networking/Cargo.toml
Normal file
44
rust/networking/Cargo.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[package]
|
||||
name = "networking"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "networking"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
74
rust/networking/examples/chatroom.rs
Normal file
74
rust/networking/examples/chatroom.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use futures::stream::StreamExt as _;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
// on gossipsub incoming
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
130
rust/networking/examples/chatroom_manual.rs
Normal file
130
rust/networking/examples/chatroom_manual.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use std::{
|
||||
error::Error,
|
||||
hash::{Hash},
|
||||
};
|
||||
use std::time::Duration;
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
@@ -0,0 +1,44 @@
|
||||
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
|
||||
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
|
||||
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
|
||||
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
|
||||
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
|
||||
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
|
||||
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
|
||||
|
||||
https://stackoverflow.com/questions/17169298/af-packet-on-osx
|
||||
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
|
||||
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
|
||||
|
||||
https://www.unix.com/man_page/mojave/4/ip/
|
||||
^OS X manpages for IP
|
||||
|
||||
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
|
||||
^driver kit, system extensions & kexts for macOS
|
||||
|
||||
----
|
||||
|
||||
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
|
||||
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
|
||||
----> here is a guide on how to set up thunderbolt-ethernet on linux
|
||||
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
|
||||
|
||||
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
|
||||
^GPT discussion about making socket interface
|
||||
|
||||
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
|
||||
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
|
||||
^GPT discussion about accessing TB on MacOS low level interactions
|
||||
|
||||
--------------------------------
|
||||
|
||||
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
|
||||
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
|
||||
|
||||
|
||||
---------------------------------
|
||||
|
||||
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
|
||||
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
|
||||
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
|
||||
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
379
rust/networking/src/discovery.rs
Normal file
379
rust/networking/src/discovery.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
use crate::ext::MultiaddrExt;
|
||||
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
mod managed {
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL))
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emmit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => self.managed.on_connection_handler_event(
|
||||
peer_id,
|
||||
connection_id,
|
||||
ev,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
}
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
}
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
64
rust/networking/src/lib.rs
Normal file
64
rust/networking/src/lib.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use std::net::IpAddr;
|
||||
use extend::ext;
|
||||
use libp2p::Multiaddr;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
|
||||
#[ext(pub, name = MultiaddrExt)]
|
||||
impl Multiaddr {
|
||||
/// If the multiaddress corresponds to a TCP address, extracts it
|
||||
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
|
||||
let mut ps = self.into_iter();
|
||||
let ip = if let Some(p) = ps.next() {
|
||||
match p {
|
||||
Protocol::Ip4(ip) => IpAddr::V4(ip),
|
||||
Protocol::Ip6(ip) => IpAddr::V6(ip),
|
||||
_ => return None
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
let Some(Protocol::Tcp(port)) = ps.next() else {
|
||||
return None;
|
||||
};
|
||||
Some((ip, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
133
rust/networking/src/swarm.rs
Normal file
133
rust/networking/src/swarm.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
///
|
||||
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
|
||||
/// even be, and how to inject the right version into this config/initialization. E.g. should
|
||||
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
|
||||
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
.with_behaviour(Behaviour::new)?
|
||||
.build();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::NETWORK_VERSION;
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
use libp2p::pnet::{PnetError, PnetOutput};
|
||||
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
|
||||
|
||||
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
|
||||
/// See [`pnet_upgrade`] for more.
|
||||
const PNET_PRESHARED_KEY: [u8; 32] = Sha3_256::new()
|
||||
.update(b"exo_discovery_network")
|
||||
.update(NETWORK_VERSION)
|
||||
.finalize();
|
||||
|
||||
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
|
||||
/// also different-versioned instances of this same network.
|
||||
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
|
||||
async fn pnet_upgrade<TSocket>(
|
||||
socket: TSocket,
|
||||
_: impl Sized,
|
||||
) -> Result<PnetOutput<TSocket>, PnetError>
|
||||
where
|
||||
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
use pnet::{PnetConfig, PreSharedKey};
|
||||
PnetConfig::new(PreSharedKey::new(PNET_PRESHARED_KEY))
|
||||
.handshake(socket)
|
||||
.await
|
||||
}
|
||||
|
||||
/// TCP/IP transport layer configuration.
|
||||
pub fn tcp_transport(
|
||||
keypair: &identity::Keypair,
|
||||
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
|
||||
use libp2p::{
|
||||
core::upgrade::Version,
|
||||
tcp::{Config, tokio},
|
||||
};
|
||||
|
||||
// `TCP_NODELAY` enabled => avoid latency
|
||||
let tcp_config = Config::default().nodelay(true);
|
||||
|
||||
// V1 + lazy flushing => 0-RTT negotiation
|
||||
let upgrade_version = Version::V1Lazy;
|
||||
|
||||
// Noise is faster than TLS + we don't care much for security
|
||||
let noise_config = noise::Config::new(keypair)?;
|
||||
|
||||
// Use default Yamux config for multiplexing
|
||||
let yamux_config = yamux::Config::default();
|
||||
|
||||
// Create new Tokio-driven TCP/IP transport layer
|
||||
let base_transport = tokio::Transport::new(tcp_config)
|
||||
.and_then(pnet_upgrade)
|
||||
.upgrade(upgrade_version)
|
||||
.authenticate(noise_config)
|
||||
.multiplex(yamux_config);
|
||||
|
||||
// Return boxed transport (to flatten complex type)
|
||||
Ok(base_transport.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
// => signed message authenticity + strict validation mode means the message-ID is
|
||||
// automatically provided by gossipsub w/out needing to provide custom message-ID function
|
||||
gossipsub::Behaviour::new(
|
||||
MessageAuthenticity::Signed(keypair.clone()),
|
||||
ConfigBuilder::default()
|
||||
.validation_mode(ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("the configuration should always be valid"),
|
||||
)
|
||||
.expect("creating gossipsub behavior should always work")
|
||||
}
|
||||
}
|
||||
7
rust/networking/tests/dummy.rs
Normal file
7
rust/networking/tests/dummy.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// maybe this will hold test in the future...??
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn does_nothing() {}
|
||||
}
|
||||
2
rust/rust-toolchain.toml
Normal file
2
rust/rust-toolchain.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
||||
47
rust/system_custodian/Cargo.toml
Normal file
47
rust/system_custodian/Cargo.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
4
rust/system_custodian/src/bin/main.rs
Normal file
4
rust/system_custodian/src/bin/main.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
69
rust/system_custodian/src/lib.rs
Normal file
69
rust/system_custodian/src/lib.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
25
rust/util/Cargo.toml
Normal file
25
rust/util/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "util"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "util"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
thiserror = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
internment = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
bon = { workspace = true }
|
||||
recursion = { workspace = true }
|
||||
53
rust/util/src/lib.rs
Normal file
53
rust/util/src/lib.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub mod nonempty;
|
||||
pub mod wakerdeque;
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub mod ext {
|
||||
use extend::ext;
|
||||
|
||||
#[ext(pub, name = BoxedSliceExt)]
|
||||
impl<T> Box<[T]> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Box<[B]>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = VecExt)]
|
||||
impl<T> Vec<T> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Vec<B>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
138
rust/util/src/nonempty.rs
Normal file
138
rust/util/src/nonempty.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::slice::SliceIndex;
|
||||
use std::{ops, slice};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
|
||||
pub struct EmptySliceError;
|
||||
|
||||
/// A pointer to a non-empty fixed-size slice allocated on the heap.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[repr(transparent)]
|
||||
pub struct NonemptyArray<T>(Box<[T]>);
|
||||
|
||||
#[allow(clippy::arbitrary_source_item_ordering)]
|
||||
impl<T> NonemptyArray<T> {
|
||||
#[inline]
|
||||
pub fn singleton(value: T) -> Self {
|
||||
Self(Box::new([value]))
|
||||
}
|
||||
|
||||
#[allow(clippy::missing_errors_doc)]
|
||||
#[inline]
|
||||
pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
|
||||
boxed_slice: S,
|
||||
) -> Result<Self, EmptySliceError> {
|
||||
let boxed_slice = boxed_slice.into();
|
||||
if boxed_slice.is_empty() {
|
||||
Err(EmptySliceError)
|
||||
} else {
|
||||
Ok(Self(boxed_slice))
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn into_boxed_slice(self) -> Box<[T]> {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn to_vec(&self) -> Vec<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.0.to_vec()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn as_slice(&self) -> &[T] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn first(&self) -> &T {
|
||||
&self.0[0]
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn last(&self) -> &T {
|
||||
&self.0[self.0.len() - 1]
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn get<I>(&self, index: I) -> Option<&I::Output>
|
||||
where
|
||||
I: SliceIndex<[T]>,
|
||||
{
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter(&self) -> slice::Iter<'_, T> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
|
||||
NonemptyArray(self.0.into_iter().map(f).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<NonemptyArray<T>> for Box<[T]> {
|
||||
#[inline]
|
||||
fn from(value: NonemptyArray<T>) -> Self {
|
||||
value.into_boxed_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ops::Index<usize> for NonemptyArray<T> {
|
||||
type Output = T;
|
||||
|
||||
#[inline]
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
self.0.index(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoIterator for NonemptyArray<T> {
|
||||
type Item = T;
|
||||
type IntoIter = std::vec::IntoIter<T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.into_boxed_slice().into_vec().into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = slice::Iter<'a, T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
55
rust/util/src/wakerdeque.rs
Normal file
55
rust/util/src/wakerdeque.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
|
||||
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
|
||||
pub struct WakerDeque<T> {
|
||||
waker: Option<Waker>,
|
||||
deque: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Debug for WakerDeque<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.deque.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WakerDeque<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
waker: None,
|
||||
deque: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, cx: &mut Context<'_>) {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
fn wake(&mut self) {
|
||||
let Some(ref mut w) = self.waker else { return };
|
||||
w.wake_by_ref();
|
||||
self.waker = None;
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_front()
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_back()
|
||||
}
|
||||
|
||||
pub fn push_front(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_front(value);
|
||||
}
|
||||
|
||||
pub fn push_back(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_back(value);
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,10 @@ import mlx.nn as nn # type: ignore
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
|
||||
def __call__(self, x: mx.array, cache: Optional[list[KVCache]]) -> mx.array: ...
|
||||
|
||||
|
||||
@@ -18,7 +19,7 @@ class Detokenizer:
|
||||
def reset(self) -> None: ...
|
||||
def add_token(self, token: int) -> None: ...
|
||||
def finalize(self) -> None: ...
|
||||
|
||||
|
||||
@property
|
||||
def last_segment(self) -> str: ...
|
||||
|
||||
@@ -27,5 +28,5 @@ class TokenizerWrapper:
|
||||
bos_token: Optional[str]
|
||||
eos_token_ids: list[int]
|
||||
detokenizer: Detokenizer
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
|
||||
|
||||
@@ -29,6 +29,7 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
mlx_rank: None | int = None
|
||||
mlx_world_size: None | int = None
|
||||
|
||||
|
||||
def mx_barrier():
|
||||
mx.eval( # type: ignore
|
||||
mx.distributed.all_sum(
|
||||
@@ -36,6 +37,7 @@ def mx_barrier():
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int) -> int:
|
||||
if mlx_rank is None:
|
||||
return value
|
||||
@@ -46,8 +48,9 @@ def broadcast_from_zero(value: int) -> int:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu))
|
||||
mx.eval(m) # type: ignore
|
||||
return int(m.item()) # type: ignore
|
||||
mx.eval(m) # type: ignore
|
||||
return int(m.item()) # type: ignore
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
@@ -83,7 +86,7 @@ def mlx_setup(
|
||||
if wired_frac_of_mrwss > 0.0:
|
||||
target_wired = int(wired_frac_of_mrwss * mrwss)
|
||||
target_wired = min(target_wired, target_cache) # don’t wire more than cache
|
||||
|
||||
|
||||
runner_print(f"{target_wired=}")
|
||||
with contextlib.suppress(Exception): # older macOS won’t have this
|
||||
mx.set_wired_limit(max(target_wired, 0))
|
||||
@@ -136,14 +139,14 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
model_shard_meta: ShardMetadata,
|
||||
model_shard_meta: ShardMetadata,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(model_shard_meta.model_meta.model_id)
|
||||
|
||||
runner_print(f"loading model from {model_path}")
|
||||
|
||||
model, config = load_model(model_path, lazy=True, strict=False) # type: ignore
|
||||
runner_print(f'{config=}')
|
||||
runner_print(f"{config=}")
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
@@ -154,7 +157,7 @@ def shard_and_load(
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier()
|
||||
|
||||
return model, tokenizer # type: ignore
|
||||
return model, tokenizer # type: ignore
|
||||
|
||||
|
||||
async def apply_chat_template(
|
||||
@@ -199,11 +202,13 @@ async def apply_chat_template(
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
"""
|
||||
A KVCache that pretends to exist but holds zero tokens.
|
||||
It satisfies .state/.meta_state and never allocates real keys/values.
|
||||
"""
|
||||
|
||||
def __init__(self, dtype: mx.Dtype = mx.float16):
|
||||
super().__init__()
|
||||
# zero-length K/V so shapes/dtypes are defined but empty
|
||||
@@ -218,19 +223,21 @@ class NullKVCache(KVCache):
|
||||
|
||||
@state.setter
|
||||
def state(self, v: tuple[mx.array, mx.array]) -> None:
|
||||
raise NotImplementedError('We should not be setting a NullKVCache.')
|
||||
raise NotImplementedError("We should not be setting a NullKVCache.")
|
||||
|
||||
|
||||
async def make_kv_cache(
|
||||
model: Model,
|
||||
max_kv_size: Optional[int] = None,
|
||||
) -> list[KVCache]:
|
||||
assert hasattr(model, 'layers')
|
||||
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
return [
|
||||
NullKVCache() if isinstance(layer, IdentityLayer) else KVCache()
|
||||
for layer in model.layers
|
||||
]
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
"""
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
|
||||
236
src/exo/main.py
236
src/exo/main.py
@@ -1,41 +1,221 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.master.main import main as master_main
|
||||
import exo.routing.topics as topics
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
from exo.shared.constants import EXO_LOG
|
||||
from exo.shared.election import Election, ElectionResult
|
||||
from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.worker.main import main as worker_main
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# TODO: Entrypoint refactor
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
# This is the collection of systems for our entire application.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
worker: Worker
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
master: Master | None
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
_tg: TaskGroup | None = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
await router.register_topic(topics.LOCAL_EVENTS)
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id=node_id,
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
else:
|
||||
api = None
|
||||
|
||||
worker = Worker(
|
||||
node_id,
|
||||
exo_shard_downloader(),
|
||||
initial_connection_messages=[],
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
tb_only=args.tb_only,
|
||||
)
|
||||
|
||||
# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
election = Election(
|
||||
node_id,
|
||||
seniority=1_000_000 if args.force_master else 0,
|
||||
# nb: this DOES feedback right now. i have thoughts on how to address this,
|
||||
# but ultimately it seems not worth the complexity
|
||||
election_message_sender=router.sender(topics.ELECTION_MESSAGES),
|
||||
election_message_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
|
||||
async def run(self):
|
||||
async with anyio.create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.worker.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.master:
|
||||
tg.start_soon(self.master.run)
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
|
||||
async def _elect_loop(self):
|
||||
assert self._tg
|
||||
with self.election_result_receiver as results:
|
||||
async for result in results:
|
||||
# I don't like this duplication, but it's manageable for now.
|
||||
# TODO: This function needs refactoring generally
|
||||
|
||||
# Ok:
|
||||
# On new master:
|
||||
# - Elect master locally if necessary
|
||||
# - Shutdown and re-create the worker
|
||||
# - Shut down and re-create the API
|
||||
|
||||
if result.node_id == self.node_id and self.master is not None:
|
||||
logger.info("Node elected Master")
|
||||
elif result.node_id == self.node_id and self.master is None:
|
||||
logger.info("Node elected Master - promoting self")
|
||||
self.master = Master(
|
||||
self.node_id,
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif result.node_id != self.node_id and self.master is not None:
|
||||
logger.info(f"Node {result.node_id} elected master - demoting self")
|
||||
await self.master.shutdown()
|
||||
self.master = None
|
||||
else:
|
||||
logger.info(f"Node {result.node_id} elected master")
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
if self.worker:
|
||||
self.worker.shutdown()
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
exo_shard_downloader(),
|
||||
initial_connection_messages=result.historic_messages,
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
self.api.reset()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="exo")
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_const", const=1, dest="verbosity", default=0
|
||||
)
|
||||
parser.add_argument(
|
||||
"-vv",
|
||||
"--very-verbose",
|
||||
action="store_const",
|
||||
const=2,
|
||||
dest="verbosity",
|
||||
default=0,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if type(args.verbosity) is not int: # type: ignore
|
||||
raise TypeError("Verbosity was parsed incorrectly")
|
||||
args = Args.parse()
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("starting exo")
|
||||
logger.info("Starting EXO")
|
||||
|
||||
# This is for future PyInstaller compatibility
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
worker = mp.Process(target=worker_main, args=(EXO_LOG, args.verbosity))
|
||||
master = mp.Process(target=master_main, args=(EXO_LOG, args.verbosity))
|
||||
worker.start()
|
||||
master.start()
|
||||
worker.join()
|
||||
master.join()
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
|
||||
logger_cleanup()
|
||||
|
||||
|
||||
class Args(CamelCaseModel):
|
||||
verbosity: int = 0
|
||||
force_master: bool = False
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 8000
|
||||
tb_only: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
parser = argparse.ArgumentParser(prog="EXO")
|
||||
default_verbosity = 0
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quiet",
|
||||
action="store_const",
|
||||
const=-1,
|
||||
dest="verbosity",
|
||||
default=default_verbosity,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="count",
|
||||
dest="verbosity",
|
||||
default=default_verbosity,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--force-master",
|
||||
action="store_true",
|
||||
dest="force_master",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-api",
|
||||
action="store_false",
|
||||
dest="spawn_api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-port",
|
||||
type=int,
|
||||
dest="api_port",
|
||||
default=8000,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tb-only",
|
||||
action="store_true",
|
||||
dest="tb_only",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -2,16 +2,18 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Callable, List, Sequence, final
|
||||
from typing import final
|
||||
|
||||
import uvicorn
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
@@ -24,23 +26,26 @@ from exo.shared.types.api import (
|
||||
ModelListModel,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.events.chunks import TokenChunk
|
||||
from exo.shared.types.events.commands import (
|
||||
ChatCompletionCommand,
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CommandType,
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
TaskFinishedCommand,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
TaggedCommand,
|
||||
# TODO: SpinUpInstance
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.events.components import EventFromEventLog
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.instances import Instance
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
|
||||
def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
|
||||
@@ -70,26 +75,50 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
command_buffer: List[Command],
|
||||
global_events: AsyncSQLiteEventStorage,
|
||||
get_state: Callable[[], State],
|
||||
*,
|
||||
node_id: NodeId,
|
||||
port: int = 8000,
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
) -> None:
|
||||
self.get_state = get_state
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.state = State()
|
||||
self.command_sender = command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
self.node_id: NodeId = node_id
|
||||
self.port = port
|
||||
|
||||
self._app = FastAPI()
|
||||
self.app = FastAPI()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
self._app.mount(
|
||||
self.app.mount(
|
||||
"/",
|
||||
StaticFiles(directory=os.environ["DASHBOARD_DIR"], html=True),
|
||||
StaticFiles(
|
||||
directory=os.environ.get(
|
||||
"DASHBOARD_DIR",
|
||||
os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../dashboard")
|
||||
),
|
||||
),
|
||||
html=True,
|
||||
),
|
||||
name="dashboard",
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId, asyncio.Queue[ChunkGenerated]
|
||||
] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self):
|
||||
self.state = State()
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self._app.add_middleware(
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
@@ -98,23 +127,19 @@ class API:
|
||||
)
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
self._app.post("/instance")(self.create_instance)
|
||||
self._app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self._app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self._app.get("/models")(self.get_models)
|
||||
self._app.get("/v1/models")(self.get_models)
|
||||
self._app.post("/v1/chat/completions")(self.chat_completions)
|
||||
self._app.get("/state")(self.get_state)
|
||||
|
||||
@property
|
||||
def app(self) -> FastAPI:
|
||||
return self._app
|
||||
self.app.post("/instance")(self.create_instance)
|
||||
self.app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/v1/chat/completions")(self.chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
|
||||
async def create_instance(
|
||||
self, payload: CreateInstanceTaskParams
|
||||
) -> CreateInstanceResponse:
|
||||
model_meta = await resolve_model_meta(payload.model_id)
|
||||
required_memory_bytes = model_meta.storage_size_kilobytes * 1024
|
||||
required_memory_bytes = model_meta.storage_size.in_kb
|
||||
available_memory_bytes = self._calculate_total_available_memory()
|
||||
|
||||
if required_memory_bytes > available_memory_bytes:
|
||||
@@ -123,37 +148,33 @@ class API:
|
||||
detail=f"Insufficient memory to create instance. Required: {required_memory_bytes // (1024**3):.1f}GB, Available: {available_memory_bytes // (1024**3):.1f}GB",
|
||||
)
|
||||
|
||||
command = CreateInstanceCommand(
|
||||
command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
command_type=CommandType.CREATE_INSTANCE,
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
self.command_buffer.append(command)
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=model_meta,
|
||||
instance_id=command.instance_id,
|
||||
)
|
||||
|
||||
def get_instance(self, instance_id: InstanceId) -> Instance:
|
||||
state = self.get_state()
|
||||
state = self.state
|
||||
if instance_id not in state.instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
return state.instances[instance_id]
|
||||
|
||||
def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
|
||||
if instance_id not in self.get_state().instances:
|
||||
async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
|
||||
if instance_id not in self.state.instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
|
||||
command = DeleteInstanceCommand(
|
||||
command = DeleteInstance(
|
||||
command_id=CommandId(),
|
||||
command_type=CommandType.DELETE_INSTANCE,
|
||||
instance_id=instance_id,
|
||||
)
|
||||
self.command_buffer.append(command)
|
||||
await self._send(command)
|
||||
return DeleteInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
@@ -165,37 +186,27 @@ class API:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
events = await self.global_events.get_events_since(0)
|
||||
prev_idx = await self.global_events.get_last_idx()
|
||||
self._chat_completion_queues[command_id] = asyncio.Queue()
|
||||
|
||||
finished = False
|
||||
while not finished:
|
||||
await asyncio.sleep(0.01)
|
||||
# TODO: how long should this timeout be?
|
||||
chunk = await asyncio.wait_for(
|
||||
self._chat_completion_queues[command_id].get(), timeout=60
|
||||
)
|
||||
if chunk.command_id == command_id:
|
||||
assert isinstance(chunk.chunk, TokenChunk)
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(chunk.chunk)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
events: Sequence[
|
||||
EventFromEventLog[Event]
|
||||
] = await self.global_events.get_events_since(prev_idx)
|
||||
# TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
|
||||
prev_idx = events[-1].idx_in_log if events else prev_idx
|
||||
if chunk.chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
finished = True
|
||||
|
||||
for wrapped_event in events:
|
||||
event = wrapped_event.event
|
||||
if isinstance(event, ChunkGenerated) and event.command_id == command_id:
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
event.chunk
|
||||
)
|
||||
logger.debug(chunk_response)
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if event.chunk.finish_reason is not None:
|
||||
yield "data: [DONE]"
|
||||
finished = True
|
||||
|
||||
command = TaskFinishedCommand(command_id=command_id)
|
||||
self.command_buffer.append(command)
|
||||
|
||||
return
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
@@ -210,6 +221,7 @@ class API:
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
# Preprocess messages for GPT-OSS harmony format if needed
|
||||
# TODO: This is slop surely we get rid
|
||||
if "gpt-oss" in payload.model.lower():
|
||||
import re
|
||||
|
||||
@@ -233,7 +245,7 @@ class API:
|
||||
# Store thinking in the thinking field
|
||||
message.thinking = thinking_match.group(1).strip()
|
||||
|
||||
for instance in self.get_state().instances.values():
|
||||
for instance in self.state.instances.values():
|
||||
if instance.shard_assignments.model_id == payload.model:
|
||||
break
|
||||
else:
|
||||
@@ -242,23 +254,22 @@ class API:
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
command = ChatCompletionCommand(
|
||||
command = ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
command_type=CommandType.CHAT_COMPLETION,
|
||||
request_params=payload,
|
||||
)
|
||||
self.command_buffer.append(command)
|
||||
await self._send(command)
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id), media_type="text/plain"
|
||||
)
|
||||
|
||||
def _calculate_total_available_memory(self) -> int:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
state = self.get_state()
|
||||
total_available = 0
|
||||
|
||||
for node_profile in state.node_profiles.values():
|
||||
total_available += node_profile.memory.ram_available
|
||||
for node in self.state.topology.list_nodes():
|
||||
if node.node_profile is not None:
|
||||
total_available += node.node_profile.memory.ram_available.in_bytes
|
||||
|
||||
return total_available
|
||||
|
||||
@@ -277,14 +288,35 @@ class API:
|
||||
]
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
uvicorn_config = uvicorn.Config(
|
||||
self.app, host="0.0.0.0", port=self.port, access_log=False
|
||||
)
|
||||
uvicorn_server = uvicorn.Server(uvicorn_config)
|
||||
|
||||
def start_fastapi_server(
|
||||
command_buffer: List[Command],
|
||||
global_events: AsyncSQLiteEventStorage,
|
||||
get_state: Callable[[], State],
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
):
|
||||
api = API(command_buffer, global_events, get_state)
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(uvicorn_server.serve)
|
||||
tg.start_soon(self._apply_state)
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
uvicorn.run(api.app, host=host, port=port)
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for event in events:
|
||||
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
self._chat_completion_queues[event.command_id].put_nowait(event)
|
||||
|
||||
async def _send(self, command: Command):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.node_id, tagged_command=TaggedCommand.from_(command)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
|
||||
|
||||
|
||||
class ElectionCallbacks:
|
||||
"""
|
||||
Simple callbacks for the Rust election system to invoke.
|
||||
No event system involvement - just direct forwarder control.
|
||||
"""
|
||||
|
||||
def __init__(self, forwarder_supervisor: ForwarderSupervisor):
|
||||
self._forwarder_supervisor = forwarder_supervisor
|
||||
|
||||
async def on_became_master(self) -> None:
|
||||
"""Called when this node is elected as master"""
|
||||
logger.info("Node elected as master")
|
||||
await self._forwarder_supervisor.notify_role_change(ForwarderRole.MASTER)
|
||||
|
||||
async def on_became_replica(self) -> None:
|
||||
"""Called when this node becomes a replica"""
|
||||
logger.info("Node demoted to replica")
|
||||
await self._forwarder_supervisor.notify_role_change(ForwarderRole.REPLICA)
|
||||
@@ -1,9 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from exo.shared.env import BaseEnv
|
||||
|
||||
|
||||
class MasterEnvironmentSchema(BaseEnv):
|
||||
# Master-specific: forwarder configuration
|
||||
# Default to build/forwarder if not explicitly set
|
||||
FORWARDER_BINARY_PATH: Path = Path("build/forwarder")
|
||||
@@ -10,7 +10,7 @@ from exo.shared.constants import (
|
||||
EXO_GLOBAL_EVENT_DB,
|
||||
EXO_WORKER_EVENT_DB,
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC,
|
||||
LIBP2P_WORKER_EVENTS_TOPIC,
|
||||
LIBP2P_LOCAL_EVENTS_TOPIC,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
@@ -58,9 +58,7 @@ class ForwarderSupervisor:
|
||||
if self._current_role == new_role:
|
||||
logger.debug(f"Role unchanged: {new_role}")
|
||||
return
|
||||
logger.bind(user_facing=True).info(
|
||||
f"Node changing from {self._current_role} to {new_role}"
|
||||
)
|
||||
logger.info(f"Node changing from {self._current_role} to {new_role}")
|
||||
self._current_role = new_role
|
||||
await self._restart_with_role(new_role)
|
||||
|
||||
@@ -82,13 +80,13 @@ class ForwarderSupervisor:
|
||||
|
||||
# Both master and replica forward local worker events to network
|
||||
pairs.append(
|
||||
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}"
|
||||
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}"
|
||||
)
|
||||
|
||||
if role == ForwarderRole.MASTER:
|
||||
# Master: collect worker events from network into global log
|
||||
pairs.append(
|
||||
f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events"
|
||||
f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events"
|
||||
)
|
||||
# Master: broadcast global events to network
|
||||
pairs.append(
|
||||
|
||||
@@ -1,272 +1,240 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.api import start_fastapi_server
|
||||
from exo.master.election_callback import ElectionCallbacks
|
||||
from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
|
||||
from exo.master.placement import get_instance_placements, get_transition_events
|
||||
from exo.master.placement import (
|
||||
get_instance_placements_after_create,
|
||||
get_instance_placements_after_delete,
|
||||
get_transition_events,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_MASTER_LOG
|
||||
from exo.shared.db.sqlite.config import EventLogConfig
|
||||
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from exo.shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from exo.shared.keypair import Keypair, get_node_id_keypair
|
||||
from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
RequestEventLog,
|
||||
SpinUpInstance,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
Heartbeat,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceDeleted,
|
||||
TaggedEvent,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TopologyEdgeDeleted,
|
||||
TopologyNodeCreated,
|
||||
)
|
||||
from exo.shared.types.events.commands import (
|
||||
ChatCompletionCommand,
|
||||
Command,
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
TaskFinishedCommand,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
from exo.shared.types.worker.instances import Instance
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
|
||||
|
||||
class Master:
|
||||
def __init__(
|
||||
self,
|
||||
node_id_keypair: Keypair,
|
||||
node_id: NodeId,
|
||||
command_buffer: list[Command],
|
||||
global_events: AsyncSQLiteEventStorage,
|
||||
worker_events: AsyncSQLiteEventStorage,
|
||||
forwarder_binary_path: Path,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
tb_only: bool = False,
|
||||
):
|
||||
self.state = State()
|
||||
self.node_id_keypair = node_id_keypair
|
||||
self._tg: TaskGroup | None = None
|
||||
self.node_id = node_id
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.worker_events = worker_events
|
||||
self.command_task_mapping: dict[CommandId, TaskId] = {}
|
||||
self.forwarder_supervisor = ForwarderSupervisor(
|
||||
self.node_id,
|
||||
forwarder_binary_path=forwarder_binary_path,
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self.election_callbacks = ElectionCallbacks(self.forwarder_supervisor)
|
||||
|
||||
@property
|
||||
def event_log_for_reads(self) -> AsyncSQLiteEventStorage:
|
||||
return self.global_events
|
||||
|
||||
@property
|
||||
def event_log_for_writes(self) -> AsyncSQLiteEventStorage:
|
||||
if self.forwarder_supervisor.current_role == ForwarderRole.MASTER:
|
||||
return self.global_events
|
||||
else:
|
||||
return self.worker_events
|
||||
|
||||
async def _get_state_snapshot(self) -> State:
|
||||
# TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
|
||||
return State()
|
||||
|
||||
async def _run_event_loop_body(self) -> None:
|
||||
next_events: list[Event] = []
|
||||
# 1. process commands
|
||||
if (
|
||||
self.forwarder_supervisor.current_role == ForwarderRole.MASTER
|
||||
and len(self.command_buffer) > 0
|
||||
):
|
||||
# for now we do one command at a time
|
||||
next_command = self.command_buffer.pop(0)
|
||||
|
||||
logger.bind(user_facing=True).info(f"Executing command: {next_command}")
|
||||
logger.info(f"Got command: {next_command}")
|
||||
|
||||
# TODO: validate the command
|
||||
match next_command:
|
||||
case ChatCompletionCommand():
|
||||
matching_instance: Instance | None = None
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== next_command.request_params.model
|
||||
):
|
||||
matching_instance = instance
|
||||
break
|
||||
if not matching_instance:
|
||||
raise ValueError(
|
||||
f"No instance found for model {next_command.request_params.model}"
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
next_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_id=task_id,
|
||||
command_id=next_command.command_id,
|
||||
instance_id=matching_instance.instance_id,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=next_command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[next_command.command_id] = task_id
|
||||
case DeleteInstanceCommand():
|
||||
placement = get_instance_placements(
|
||||
next_command, self.state.topology, self.state.instances
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
next_events.extend(transition_events)
|
||||
case CreateInstanceCommand():
|
||||
placement = get_instance_placements(
|
||||
next_command, self.state.topology, self.state.instances
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
next_events.extend(transition_events)
|
||||
case TaskFinishedCommand():
|
||||
next_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[next_command.command_id]
|
||||
)
|
||||
)
|
||||
del self.command_task_mapping[next_command.command_id]
|
||||
|
||||
await self.event_log_for_writes.append_events(
|
||||
next_events, origin=self.node_id
|
||||
)
|
||||
# 2. get latest events
|
||||
events = await self.event_log_for_reads.get_events_since(
|
||||
self.state.last_event_applied_idx, ignore_no_op_events=True
|
||||
)
|
||||
if len(events) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
return
|
||||
|
||||
if len(events) == 1:
|
||||
logger.debug(f"Master received event: {events[0]}")
|
||||
else:
|
||||
logger.debug(f"Master received events: {events}")
|
||||
|
||||
# 3. for each event, apply it to the state
|
||||
for event_from_log in events:
|
||||
logger.trace(f"Applying event: {event_from_log}")
|
||||
self.state = apply(self.state, event_from_log)
|
||||
logger.trace(f"State: {self.state.model_dump_json()}")
|
||||
|
||||
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
|
||||
write_events: list[Event] = []
|
||||
if any(
|
||||
[
|
||||
isinstance(event_from_log.event, TopologyEdgeDeleted)
|
||||
for event_from_log in events
|
||||
]
|
||||
):
|
||||
connected_node_ids = set(
|
||||
[x.node_id for x in self.state.topology.list_nodes()]
|
||||
)
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
delete = False
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
delete = True
|
||||
break
|
||||
if delete:
|
||||
write_events.append(InstanceDeleted(instance_id=instance_id))
|
||||
|
||||
if write_events:
|
||||
await self.event_log_for_writes.append_events(
|
||||
events=write_events, origin=self.node_id
|
||||
)
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self.tb_only = tb_only
|
||||
|
||||
async def run(self):
|
||||
self.state = await self._get_state_snapshot()
|
||||
logger.info("Starting Master")
|
||||
|
||||
async def heartbeat_task():
|
||||
while True:
|
||||
await self.event_log_for_writes.append_events(
|
||||
[Heartbeat(node_id=self.node_id)], origin=self.node_id
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
if self._tg:
|
||||
logger.info("Stopping Master")
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.command_receiver as commands:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
logger.info(
|
||||
f"Executing command: {forwarder_command.tagged_command.c}"
|
||||
)
|
||||
generated_events: list[Event] = []
|
||||
command = forwarder_command.tagged_command.c
|
||||
match command:
|
||||
case ChatCompletion():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
logger.warning(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
continue
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case DeleteInstance():
|
||||
placement = get_instance_placements_after_delete(
|
||||
command, self.state.instances
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
placement = get_instance_placements_after_create(
|
||||
command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
tb_only=self.tb_only,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case SpinUpInstance():
|
||||
raise NotImplementedError
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
await self._send_event(
|
||||
IndexedEvent(idx=i, event=self._event_log[i])
|
||||
)
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning("Error in command processor")
|
||||
|
||||
async def _event_processor(self) -> None:
|
||||
with self.local_event_receiver as local_events:
|
||||
async for local_event in local_events:
|
||||
self._multi_buffer.ingest(
|
||||
local_event.origin_idx,
|
||||
local_event.tagged_event.c,
|
||||
local_event.origin,
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
for event in self._multi_buffer.drain():
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
# TODO: SQL
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
asyncio.create_task(heartbeat_task())
|
||||
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
|
||||
if isinstance(event, TopologyEdgeDeleted):
|
||||
connected_node_ids = set(
|
||||
[x.node_id for x in self.state.topology.list_nodes()]
|
||||
)
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
await self.event_sender.send(
|
||||
InstanceDeleted(instance_id=instance_id)
|
||||
)
|
||||
break
|
||||
|
||||
# TODO: we should clean these up on shutdown
|
||||
await self.forwarder_supervisor.start_as_replica()
|
||||
if os.getenv("EXO_RUN_AS_REPLICA") in set(["TRUE", "true", "1"]):
|
||||
await self.election_callbacks.on_became_replica()
|
||||
else:
|
||||
await self.election_callbacks.on_became_master()
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
tagged_event=TaggedEvent.from_(event),
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
|
||||
role = (
|
||||
"MASTER"
|
||||
if self.forwarder_supervisor.current_role == ForwarderRole.MASTER
|
||||
else "REPLICA"
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
# Convenience method since this line is ugly
|
||||
await self.global_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=self.node_id,
|
||||
origin_idx=event.idx,
|
||||
tagged_event=TaggedEvent.from_(event.event),
|
||||
)
|
||||
)
|
||||
await self.event_log_for_writes.append_events(
|
||||
[TopologyNodeCreated(node_id=self.node_id, role=role)], origin=self.node_id
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
await self._run_event_loop_body()
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error(f"Error in _run_event_loop_body: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def async_main():
|
||||
node_id_keypair = get_node_id_keypair()
|
||||
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
|
||||
|
||||
event_log_manager = EventLogManager(EventLogConfig())
|
||||
await event_log_manager.initialize()
|
||||
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
|
||||
worker_events: AsyncSQLiteEventStorage = event_log_manager.worker_events
|
||||
|
||||
command_buffer: list[Command] = []
|
||||
|
||||
logger.info("Starting EXO Master")
|
||||
logger.info(f"Starting Master with node_id: {node_id}")
|
||||
|
||||
api_port = int(os.environ.get("API_PORT", 8000))
|
||||
|
||||
api_thread = threading.Thread(
|
||||
target=start_fastapi_server,
|
||||
args=(command_buffer, global_events, lambda: master.state, "0.0.0.0", api_port),
|
||||
daemon=True,
|
||||
)
|
||||
api_thread.start()
|
||||
logger.bind(user_facing=True).info(f"Dashboard started on port {api_port}.")
|
||||
|
||||
master = Master(
|
||||
node_id_keypair,
|
||||
node_id,
|
||||
command_buffer,
|
||||
global_events,
|
||||
worker_events,
|
||||
Path(os.environ["GO_BUILD_DIR"]) / "forwarder",
|
||||
)
|
||||
await master.run()
|
||||
logger_cleanup() # pyright: ignore[reportUnreachable]
|
||||
|
||||
|
||||
def main(logfile: Path = EXO_MASTER_LOG, verbosity: int = 1):
|
||||
logger_setup(logfile, verbosity)
|
||||
asyncio.run(async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
import random
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Sequence
|
||||
|
||||
from exo.master.utils.placement_utils import (
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
)
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events.commands import (
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus
|
||||
|
||||
@@ -25,26 +25,24 @@ def random_ephemeral_port() -> int:
|
||||
return random.randint(49152, 65535)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def get_instance_placements(
|
||||
command: CreateInstanceCommand,
|
||||
def get_instance_placements_after_create(
|
||||
command: CreateInstance,
|
||||
topology: Topology,
|
||||
current_instances: dict[InstanceId, Instance],
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
*,
|
||||
tb_only: bool = False,
|
||||
) -> dict[InstanceId, Instance]:
|
||||
available_models = [
|
||||
current_instances[instance].shard_assignments.model_id
|
||||
for instance in current_instances
|
||||
]
|
||||
if command.model_meta.model_id in available_models:
|
||||
raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
|
||||
|
||||
all_nodes = list(topology.list_nodes())
|
||||
cycles = topology.get_cycles()
|
||||
from loguru import logger
|
||||
|
||||
logger.info("finding cycles:")
|
||||
cycles = topology.get_cycles_tb()
|
||||
logger.info(f"{cycles=}")
|
||||
# we can also always just have a node on its own
|
||||
singleton_cycles = [[node] for node in all_nodes]
|
||||
candidate_cycles = cycles + singleton_cycles
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, command.model_meta.storage_size_kilobytes * 1024
|
||||
candidate_cycles, command.model_meta.storage_size
|
||||
)
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
@@ -52,25 +50,27 @@ def get_instance_placements(
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
selected_cycle = None
|
||||
|
||||
has_thunderbolt_cycle = any(
|
||||
[
|
||||
topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
|
||||
for cycle in smallest_cycles
|
||||
]
|
||||
)
|
||||
if has_thunderbolt_cycle:
|
||||
smallest_cycles = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
|
||||
]
|
||||
smallest_tb_cycles = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
|
||||
]
|
||||
|
||||
if tb_only and smallest_tb_cycles == []:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
|
||||
elif smallest_tb_cycles != []:
|
||||
smallest_cycles = smallest_tb_cycles
|
||||
|
||||
selected_cycle = max(
|
||||
smallest_cycles,
|
||||
key=lambda cycle: sum(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile is not None
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile is not None
|
||||
),
|
||||
start=Memory(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -79,8 +79,8 @@ def get_instance_placements(
|
||||
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
|
||||
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
|
||||
|
||||
instance_id = command.instance_id
|
||||
target_instances = deepcopy(current_instances)
|
||||
instance_id = InstanceId()
|
||||
target_instances = dict(deepcopy(current_instances))
|
||||
target_instances[instance_id] = Instance(
|
||||
instance_id=instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
@@ -88,6 +88,9 @@ def get_instance_placements(
|
||||
hosts=[
|
||||
Host(
|
||||
ip=host.ip,
|
||||
# NOTE: this is stupid
|
||||
# |
|
||||
# v
|
||||
# NOTE: it's fine to have non-deterministic ports here since this is in a command decision
|
||||
port=random_ephemeral_port(),
|
||||
)
|
||||
@@ -97,13 +100,11 @@ def get_instance_placements(
|
||||
return target_instances
|
||||
|
||||
|
||||
@get_instance_placements.register
|
||||
def _(
|
||||
command: DeleteInstanceCommand,
|
||||
topology: Topology,
|
||||
current_instances: dict[InstanceId, Instance],
|
||||
def get_instance_placements_after_delete(
|
||||
command: DeleteInstance,
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
) -> dict[InstanceId, Instance]:
|
||||
target_instances = deepcopy(current_instances)
|
||||
target_instances = dict(deepcopy(current_instances))
|
||||
if command.instance_id in target_instances:
|
||||
del target_instances[command.instance_id]
|
||||
return target_instances
|
||||
|
||||
@@ -4,9 +4,10 @@ from pydantic import BaseModel
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.topology import Node
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.common import RunnerId
|
||||
from exo.shared.types.worker.runners import ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -17,38 +18,41 @@ class NodeWithProfile(BaseModel):
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]:
|
||||
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
|
||||
return all(node.node_profile is not None for node in nodes)
|
||||
|
||||
|
||||
def filter_cycles_by_memory(
|
||||
cycles: list[list[Node]], required_memory: int
|
||||
) -> list[list[Node]]:
|
||||
filtered_cycles: list[list[Node]] = []
|
||||
cycles: list[list[NodeInfo]], required_memory: Memory
|
||||
) -> list[list[NodeInfo]]:
|
||||
filtered_cycles: list[list[NodeInfo]] = []
|
||||
for cycle in cycles:
|
||||
if not narrow_all_nodes(cycle):
|
||||
continue
|
||||
|
||||
total_mem = sum(node.node_profile.memory.ram_available for node in cycle)
|
||||
total_mem = sum(
|
||||
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
|
||||
)
|
||||
if total_mem >= required_memory:
|
||||
filtered_cycles.append(cast(list[Node], cycle))
|
||||
filtered_cycles.append(cast(list[NodeInfo], cycle))
|
||||
return filtered_cycles
|
||||
|
||||
|
||||
def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]:
|
||||
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
|
||||
min_nodes = min(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
|
||||
|
||||
def get_shard_assignments(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[Node],
|
||||
selected_cycle: list[NodeInfo],
|
||||
) -> ShardAssignments:
|
||||
if not narrow_all_nodes(selected_cycle):
|
||||
raise ValueError("All nodes must have profiles to create shard assignments")
|
||||
|
||||
cycle_memory = sum(
|
||||
node.node_profile.memory.ram_available for node in selected_cycle
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
start=Memory(),
|
||||
)
|
||||
total_layers = model_meta.n_layers
|
||||
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
|
||||
@@ -60,7 +64,11 @@ def get_shard_assignments(
|
||||
node_layers = total_layers - layers_assigned
|
||||
else:
|
||||
node_layers = round(
|
||||
total_layers * (node.node_profile.memory.ram_available / cycle_memory)
|
||||
total_layers
|
||||
* (
|
||||
node.node_profile.memory.ram_available.in_bytes
|
||||
/ cycle_memory.in_bytes
|
||||
)
|
||||
)
|
||||
node_layers = max(1, node_layers)
|
||||
|
||||
@@ -109,6 +117,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
):
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
assert connection.send_back_multiaddr is not None
|
||||
host = Host(
|
||||
ip=connection.send_back_multiaddr.ip_address,
|
||||
port=connection.send_back_multiaddr.port,
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
@@ -7,21 +9,21 @@ from exo.shared.types.profiling import (
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, Node
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_node():
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> Node:
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return Node(
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
@@ -37,7 +39,7 @@ def create_node():
|
||||
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
@pytest.fixture
|
||||
def create_connection():
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
|
||||
def _create_connection(
|
||||
@@ -50,7 +52,6 @@ def create_connection():
|
||||
return Connection(
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/127.0.0.1/tcp/{send_back_port}"
|
||||
),
|
||||
|
||||
@@ -23,9 +23,8 @@ from exo.shared.constants import (
|
||||
EXO_GLOBAL_EVENT_DB,
|
||||
EXO_WORKER_EVENT_DB,
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC,
|
||||
LIBP2P_WORKER_EVENTS_TOPIC,
|
||||
LIBP2P_LOCAL_EVENTS_TOPIC,
|
||||
)
|
||||
from exo.shared.logging import logger_test_install
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
# Mock forwarder script content
|
||||
@@ -192,7 +191,6 @@ class TestForwardersupervisorBasic:
|
||||
],
|
||||
) -> None:
|
||||
"""Test starting forwarder in replica mode."""
|
||||
logger_test_install(test_logger)
|
||||
# Set environment
|
||||
os.environ.update(mock_env_vars)
|
||||
|
||||
@@ -216,7 +214,7 @@ class TestForwardersupervisorBasic:
|
||||
|
||||
# Expected replica forwarding pairs
|
||||
expected_pairs = [
|
||||
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}",
|
||||
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}",
|
||||
f"libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
|
||||
]
|
||||
|
||||
@@ -238,7 +236,6 @@ class TestForwardersupervisorBasic:
|
||||
],
|
||||
) -> None:
|
||||
"""Test changing role from replica to master."""
|
||||
logger_test_install(test_logger)
|
||||
os.environ.update(mock_env_vars)
|
||||
|
||||
supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script)
|
||||
@@ -265,7 +262,7 @@ class TestForwardersupervisorBasic:
|
||||
|
||||
# Expected master forwarding pairs
|
||||
master_pairs = [
|
||||
f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
|
||||
f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
|
||||
f"sqlite:{EXO_GLOBAL_EVENT_DB}:events|libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}",
|
||||
]
|
||||
|
||||
@@ -285,7 +282,6 @@ class TestForwardersupervisorBasic:
|
||||
],
|
||||
) -> None:
|
||||
"""Test that setting the same role twice doesn't restart the process."""
|
||||
logger_test_install(test_logger)
|
||||
os.environ.update(mock_env_vars)
|
||||
|
||||
supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script)
|
||||
@@ -316,7 +312,6 @@ class TestForwardersupervisorBasic:
|
||||
],
|
||||
) -> None:
|
||||
"""Test that Forwardersupervisor restarts the process if it crashes."""
|
||||
logger_test_install(test_logger)
|
||||
# Configure mock to exit after 1 second
|
||||
mock_env_vars["MOCK_EXIT_AFTER"] = "1"
|
||||
mock_env_vars["MOCK_EXIT_CODE"] = "1"
|
||||
@@ -365,7 +360,6 @@ class TestForwardersupervisorBasic:
|
||||
self, test_logger: logging.Logger, temp_dir: Path
|
||||
) -> None:
|
||||
"""Test behavior when forwarder binary doesn't exist."""
|
||||
logger_test_install(test_logger)
|
||||
nonexistent_path = temp_dir / "nonexistent_forwarder"
|
||||
|
||||
supervisor = ForwarderSupervisor(NodeId(), nonexistent_path)
|
||||
@@ -381,7 +375,6 @@ class TestElectionCallbacks:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_became_master(self, test_logger: logging.Logger) -> None:
|
||||
"""Test callback when becoming master."""
|
||||
logger_test_install(test_logger)
|
||||
mock_supervisor = MagicMock(spec=ForwarderSupervisor)
|
||||
mock_supervisor.notify_role_change = AsyncMock()
|
||||
|
||||
@@ -393,7 +386,6 @@ class TestElectionCallbacks:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_became_replica(self, test_logger: logging.Logger) -> None:
|
||||
"""Test callback when becoming replica."""
|
||||
logger_test_install(test_logger)
|
||||
mock_supervisor = MagicMock(spec=ForwarderSupervisor)
|
||||
mock_supervisor.notify_role_change = AsyncMock()
|
||||
|
||||
|
||||
@@ -1,30 +1,29 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.db.sqlite.config import EventLogConfig
|
||||
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from exo.shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from exo.shared.db.config import EventLogConfig
|
||||
from exo.shared.db.connector import AsyncSQLiteEventStorage
|
||||
from exo.shared.db.event_log_manager import EventLogManager
|
||||
from exo.shared.keypair import Keypair
|
||||
from exo.shared.logging import logger_test_install
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, EventFromEventLog, Heartbeat, TaskCreated
|
||||
from exo.shared.types.events._events import (
|
||||
InstanceCreated,
|
||||
NodePerformanceMeasured,
|
||||
TopologyNodeCreated,
|
||||
)
|
||||
from exo.shared.types.events.commands import (
|
||||
ChatCompletionCommand,
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CommandId,
|
||||
CreateInstanceCommand,
|
||||
CreateInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
TopologyNodeCreated,
|
||||
)
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.profiling import (
|
||||
@@ -35,7 +34,6 @@ from exo.shared.types.profiling import (
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceStatus,
|
||||
ShardAssignments,
|
||||
)
|
||||
@@ -43,7 +41,7 @@ from exo.shared.types.worker.shards import PartitionStrategy, PipelineShardMetad
|
||||
|
||||
|
||||
def _create_forwarder_dummy_binary() -> Path:
|
||||
path = Path(tempfile.mktemp()) / "forwarder.bin"
|
||||
path = Path(tempfile.mkstemp()[1]) / "forwarder.bin"
|
||||
if not path.exists():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
|
||||
@@ -53,23 +51,20 @@ def _create_forwarder_dummy_binary() -> Path:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
logger = Logger(name="test_master_logger")
|
||||
logger_test_install(logger)
|
||||
event_log_manager = EventLogManager(EventLogConfig())
|
||||
await event_log_manager.initialize()
|
||||
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
async def _get_events() -> Sequence[EventFromEventLog[Event]]:
|
||||
async def _get_events() -> Sequence[IndexedEvent]:
|
||||
orig_events = await global_events.get_events_since(0)
|
||||
override_idx_in_log = 1
|
||||
events: List[EventFromEventLog[Event]] = []
|
||||
events: List[IndexedEvent] = []
|
||||
for e in orig_events:
|
||||
if isinstance(e.event, Heartbeat):
|
||||
continue
|
||||
events.append(
|
||||
EventFromEventLog(
|
||||
event=e.event, origin=e.origin, idx_in_log=override_idx_in_log
|
||||
IndexedEvent(
|
||||
event=e.event,
|
||||
idx=override_idx_in_log, # origin=e.origin,
|
||||
)
|
||||
)
|
||||
override_idx_in_log += 1
|
||||
@@ -120,9 +115,8 @@ async def test_master():
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
command_buffer.append(
|
||||
CreateInstanceCommand(
|
||||
CreateInstance(
|
||||
command_id=CommandId(),
|
||||
instance_id=InstanceId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id="llama-3.2-1b",
|
||||
pretty_name="Llama 3.2 1B",
|
||||
@@ -134,7 +128,7 @@ async def test_master():
|
||||
while len(master.state.instances.keys()) == 0:
|
||||
await asyncio.sleep(0.001)
|
||||
command_buffer.append(
|
||||
ChatCompletionCommand(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
@@ -150,7 +144,7 @@ async def test_master():
|
||||
events = await _get_events()
|
||||
print(events)
|
||||
assert len(events) == 4
|
||||
assert events[0].idx_in_log == 1
|
||||
assert events[0].idx == 1
|
||||
assert isinstance(events[0].event, TopologyNodeCreated)
|
||||
assert isinstance(events[1].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[2].event, InstanceCreated)
|
||||
|
||||
@@ -2,15 +2,17 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.placement import get_instance_placements, get_transition_events
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events._events import (
|
||||
_EventType, # pyright: ignore[reportPrivateUsage]
|
||||
from exo.master.placement import (
|
||||
get_instance_placements_after_create,
|
||||
get_transition_events,
|
||||
)
|
||||
from exo.shared.types.events.commands import CreateInstanceCommand
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.topology import Connection, Node
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import CreateInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus
|
||||
from exo.shared.types.worker.runners import ShardAssignments
|
||||
@@ -27,7 +29,7 @@ def instance() -> Instance:
|
||||
instance_id=InstanceId(),
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id="test-model", runner_to_shard={}, node_to_runner={}
|
||||
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
|
||||
),
|
||||
hosts=[],
|
||||
)
|
||||
@@ -36,18 +38,17 @@ def instance() -> Instance:
|
||||
@pytest.fixture
|
||||
def model_meta() -> ModelMetadata:
|
||||
return ModelMetadata(
|
||||
model_id="test-model",
|
||||
storage_size_kilobytes=1000,
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
)
|
||||
|
||||
|
||||
def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
|
||||
return CreateInstanceCommand(
|
||||
def create_instance_command(model_meta: ModelMetadata) -> CreateInstance:
|
||||
return CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
|
||||
|
||||
@@ -65,32 +66,33 @@ def test_get_instance_placements_create_instance(
|
||||
expected_layers: tuple[int, int, int],
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[Memory, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
model_meta.n_layers = total_layers
|
||||
model_meta.storage_size_kilobytes = sum(
|
||||
model_meta.storage_size.in_bytes = sum(
|
||||
available_memory
|
||||
) # make it exactly fit across all nodes
|
||||
|
||||
create_instance_command = CreateInstanceCommand(
|
||||
create_instance_command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta,
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
topology.add_node(create_node(available_memory[0] * 1024, node_id_a))
|
||||
topology.add_node(create_node(available_memory[1] * 1024, node_id_b))
|
||||
topology.add_node(create_node(available_memory[2] * 1024, node_id_c))
|
||||
topology.add_node(create_node(Memory.from_bytes(available_memory[0]), node_id_a))
|
||||
topology.add_node(create_node(Memory.from_bytes(available_memory[1]), node_id_b))
|
||||
topology.add_node(create_node(Memory.from_bytes(available_memory[2]), node_id_c))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
|
||||
# act
|
||||
placements = get_instance_placements(create_instance_command, topology, {})
|
||||
placements = get_instance_placements_after_create(
|
||||
create_instance_command, topology, {}
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
@@ -117,22 +119,23 @@ def test_get_instance_placements_create_instance(
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_exact_fit(
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
create_instance_command = CreateInstanceCommand(
|
||||
create_instance_command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id="test-model",
|
||||
storage_size_kilobytes=1000,
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
),
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
placements = get_instance_placements(create_instance_command, topology, {})
|
||||
placements = get_instance_placements_after_create(
|
||||
create_instance_command, topology, {}
|
||||
)
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -144,22 +147,23 @@ def test_get_instance_placements_one_node_exact_fit(
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1001 * 1024, node_id))
|
||||
create_instance_command = CreateInstanceCommand(
|
||||
create_instance_command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id="test-model",
|
||||
storage_size_kilobytes=1000,
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
),
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
placements = get_instance_placements(create_instance_command, topology, {})
|
||||
placements = get_instance_placements_after_create(
|
||||
create_instance_command, topology, {}
|
||||
)
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -171,27 +175,26 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_not_fit(
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
create_instance_command = CreateInstanceCommand(
|
||||
create_instance_command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id="test-model",
|
||||
storage_size_kilobytes=1001,
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1001),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
),
|
||||
instance_id=InstanceId(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
|
||||
get_instance_placements(create_instance_command, topology, {})
|
||||
get_instance_placements_after_create(create_instance_command, topology, {})
|
||||
|
||||
|
||||
def test_get_transition_events_no_change(topology: Topology, instance: Instance):
|
||||
def test_get_transition_events_no_change(instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances = {instance_id: instance}
|
||||
@@ -204,7 +207,7 @@ def test_get_transition_events_no_change(topology: Topology, instance: Instance)
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
def test_get_transition_events_create_instance(topology: Topology, instance: Instance):
|
||||
def test_get_transition_events_create_instance(instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {}
|
||||
@@ -215,10 +218,10 @@ def test_get_transition_events_create_instance(topology: Topology, instance: Ins
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == _EventType.InstanceCreated
|
||||
assert isinstance(events[0], InstanceCreated)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance(topology: Topology, instance: Instance):
|
||||
def test_get_transition_events_delete_instance(instance: Instance):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
@@ -229,5 +232,5 @@ def test_get_transition_events_delete_instance(topology: Topology, instance: Ins
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == _EventType.InstanceDeleted
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
assert events[0].instance_id == instance_id
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.utils.placement_utils import (
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_shard_assignments,
|
||||
@@ -11,8 +10,9 @@ from exo.master.utils.placement_utils import (
|
||||
)
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.topology import Connection, Node
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -23,7 +23,7 @@ def topology() -> Topology:
|
||||
|
||||
def test_filter_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
@@ -47,7 +47,7 @@ def test_filter_cycles_by_memory(
|
||||
assert len(cycles[0]) == 2
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, 1)
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -57,7 +57,7 @@ def test_filter_cycles_by_memory(
|
||||
|
||||
def test_filter_cycles_by_insufficient_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
@@ -77,7 +77,9 @@ def test_filter_cycles_by_insufficient_memory(
|
||||
topology.add_connection(connection2)
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001 * 1024)
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
topology.get_cycles(), Memory.from_kb(2001)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 0
|
||||
@@ -85,7 +87,7 @@ def test_filter_cycles_by_insufficient_memory(
|
||||
|
||||
def test_filter_multiple_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
@@ -110,7 +112,7 @@ def test_filter_multiple_cycles_by_memory(
|
||||
cycles = topology.get_cycles()
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, 1500 * 1024)
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -124,7 +126,7 @@ def test_filter_multiple_cycles_by_memory(
|
||||
|
||||
def test_get_smallest_cycles(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
@@ -164,7 +166,7 @@ def test_get_smallest_cycles(
|
||||
)
|
||||
def test_get_shard_assignments(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
@@ -189,10 +191,10 @@ def test_get_shard_assignments(
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
model_meta = ModelMetadata(
|
||||
model_id="test-model",
|
||||
model_id=ModelId("test-model"),
|
||||
pretty_name="Test Model",
|
||||
n_layers=total_layers,
|
||||
storage_size_kilobytes=1000,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
@@ -223,7 +225,7 @@ def test_get_shard_assignments(
|
||||
|
||||
def test_get_hosts_from_subgraph(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], Node],
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
# arrange
|
||||
@@ -250,9 +252,9 @@ def test_get_hosts_from_subgraph(
|
||||
# assert
|
||||
assert len(hosts) == 3
|
||||
expected_hosts = [
|
||||
Host(ip=IPv4Address("127.0.0.1"), port=5001),
|
||||
Host(ip=IPv4Address("127.0.0.1"), port=5002),
|
||||
Host(ip=IPv4Address("127.0.0.1"), port=5003),
|
||||
Host(ip=("127.0.0.1"), port=5001),
|
||||
Host(ip=("127.0.0.1"), port=5002),
|
||||
Host(ip=("127.0.0.1"), port=5003),
|
||||
]
|
||||
for expected_host in expected_hosts:
|
||||
assert expected_host in hosts
|
||||
|
||||
@@ -7,7 +7,7 @@ from exo.shared.types.profiling import (
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, Node, NodeId
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,7 +20,6 @@ def connection() -> Connection:
|
||||
return Connection(
|
||||
local_node_id=NodeId(),
|
||||
send_back_node_id=NodeId(),
|
||||
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
@@ -30,7 +29,7 @@ def connection() -> Connection:
|
||||
|
||||
@pytest.fixture
|
||||
def node_profile() -> NodePerformanceProfile:
|
||||
memory_profile = MemoryPerformanceProfile(
|
||||
memory_profile = MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
)
|
||||
system_profile = SystemPerformanceProfile(flops_fp16=1000)
|
||||
@@ -54,7 +53,7 @@ def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
||||
node_id = NodeId()
|
||||
|
||||
# act
|
||||
topology.add_node(Node(node_id=node_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(node_id)
|
||||
@@ -65,9 +64,11 @@ def test_add_connection(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -82,9 +83,11 @@ def test_update_node_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -92,7 +95,7 @@ def test_update_node_profile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
),
|
||||
network_interfaces=[],
|
||||
@@ -113,9 +116,11 @@ def test_update_connection_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -125,7 +130,6 @@ def test_update_connection_profile(
|
||||
connection = Connection(
|
||||
local_node_id=connection.local_node_id,
|
||||
send_back_node_id=connection.send_back_node_id,
|
||||
local_multiaddr=connection.local_multiaddr,
|
||||
send_back_multiaddr=connection.send_back_multiaddr,
|
||||
connection_profile=new_connection_profile,
|
||||
)
|
||||
@@ -142,9 +146,11 @@ def test_remove_connection_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -155,64 +161,15 @@ def test_remove_connection_still_connected(
|
||||
assert topology.get_connection_profile(connection) is None
|
||||
|
||||
|
||||
def test_remove_connection_bridge(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
"""Create a bridge scenario: master -> node_a -> node_b
|
||||
and remove the bridge connection (master -> node_a)"""
|
||||
# arrange
|
||||
master_id = NodeId()
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
|
||||
topology.add_node(Node(node_id=master_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=node_a_id, node_profile=node_profile))
|
||||
topology.add_node(Node(node_id=node_b_id, node_profile=node_profile))
|
||||
|
||||
topology.set_master_node_id(master_id)
|
||||
|
||||
connection_master_to_a = Connection(
|
||||
local_node_id=master_id,
|
||||
send_back_node_id=node_a_id,
|
||||
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
connection_a_to_b = Connection(
|
||||
local_node_id=node_a_id,
|
||||
send_back_node_id=node_b_id,
|
||||
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1236"),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1237"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
topology.add_connection(connection_master_to_a)
|
||||
topology.add_connection(connection_a_to_b)
|
||||
|
||||
assert len(list(topology.list_nodes())) == 3
|
||||
|
||||
topology.remove_connection(connection_master_to_a)
|
||||
|
||||
remaining_nodes = list(topology.list_nodes())
|
||||
assert len(remaining_nodes) == 1
|
||||
assert remaining_nodes[0].node_id == master_id
|
||||
|
||||
assert topology.get_node_profile(node_a_id) is None
|
||||
assert topology.get_node_profile(node_b_id) is None
|
||||
|
||||
|
||||
def test_remove_node_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -227,9 +184,11 @@ def test_list_nodes(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
@@ -238,7 +197,7 @@ def test_list_nodes(
|
||||
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, Node) for node in nodes)
|
||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {
|
||||
connection.local_node_id,
|
||||
connection.send_back_node_id,
|
||||
|
||||
0
src/exo/routing/__init__.py
Normal file
0
src/exo/routing/__init__.py
Normal file
37
src/exo/routing/connection_message.py
Normal file
37
src/exo/routing/connection_message.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
242
src/exo/routing/router.py
Normal file
242
src/exo/routing/router.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from copy import copy
|
||||
from itertools import count
|
||||
from math import inf
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
from .connection_message import ConnectionMessage
|
||||
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
|
||||
|
||||
|
||||
# A significant current limitation of the TopicRouter is that it is not capable
|
||||
# of preventing feedback, as it does not ask for a system id so cannot tell
|
||||
# which message is coming/going to which system.
|
||||
# This is currently only relevant for Election
|
||||
class TopicRouter[T: CamelCaseModel]:
|
||||
def __init__(
|
||||
self,
|
||||
topic: TypedTopic[T],
|
||||
networking_sender: Sender[tuple[str, bytes]],
|
||||
max_buffer_size: float = inf,
|
||||
):
|
||||
self.topic: TypedTopic[T] = topic
|
||||
self.senders: set[Sender[T]] = set()
|
||||
send, recv = channel[T]()
|
||||
self.receiver: Receiver[T] = recv
|
||||
self.temp_sender: Sender[T] | None = send
|
||||
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender
|
||||
|
||||
async def run(self):
|
||||
logger.debug(f"Topic Router {self.topic} ready to send")
|
||||
with self.receiver as items:
|
||||
async for item in items:
|
||||
# Check if we should send to network
|
||||
if (
|
||||
len(self.senders) == 0
|
||||
and self.topic.publish_policy is PublishPolicy.Minimal
|
||||
):
|
||||
await self._send_out(item)
|
||||
continue
|
||||
if self.topic.publish_policy is PublishPolicy.Always:
|
||||
await self._send_out(item)
|
||||
# Then publish to all senders
|
||||
await self.publish(item)
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug(f"Shutting down Topic Router {self.topic}")
|
||||
# Close all the things!
|
||||
for sender in self.senders:
|
||||
sender.close()
|
||||
if self.temp_sender:
|
||||
self.temp_sender.close()
|
||||
self.receiver.close()
|
||||
|
||||
async def publish(self, item: T):
|
||||
"""
|
||||
Publish item T on this topic to all senders.
|
||||
NB: this sends to ALL receivers, potentially including receivers held by the object doing the sending.
|
||||
You should handle your own output if you hold a sender + receiver pair.
|
||||
"""
|
||||
to_clear: set[Sender[T]] = set()
|
||||
for sender in copy(self.senders):
|
||||
try:
|
||||
await sender.send(item)
|
||||
except (ClosedResourceError, BrokenResourceError):
|
||||
to_clear.add(sender)
|
||||
self.senders -= to_clear
|
||||
|
||||
async def publish_bytes(self, data: bytes):
|
||||
await self.publish(self.topic.deserialize(data))
|
||||
|
||||
async def _send_out(self, item: T):
|
||||
logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
|
||||
await self.networking_sender.send(
|
||||
(str(self.topic.topic), self.topic.serialize(item))
|
||||
)
|
||||
|
||||
|
||||
class Router:
|
||||
@classmethod
|
||||
def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=NetworkingHandle(identity))
|
||||
|
||||
def __init__(self, handle: NetworkingHandle):
|
||||
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
|
||||
send, recv = channel[tuple[str, bytes]]()
|
||||
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
|
||||
self._net: NetworkingHandle = handle
|
||||
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
|
||||
self._id_count = count()
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
else:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
# There's gotta be a way to do this without THIS many asserts
|
||||
assert router is not None
|
||||
assert router.topic == topic
|
||||
send: Sender[T] | None = cast(Sender[T] | None, router.temp_sender)
|
||||
if send:
|
||||
router.temp_sender = None
|
||||
return send
|
||||
try:
|
||||
sender = cast(Receiver[T], router.receiver).clone_sender()
|
||||
except ClosedResourceError:
|
||||
sender, router.receiver = cast(
|
||||
tuple[Sender[T], Receiver[CamelCaseModel]], channel[T]()
|
||||
)
|
||||
return sender
|
||||
|
||||
def receiver[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Receiver[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
# There's gotta be a way to do this without THIS many asserts
|
||||
|
||||
assert router is not None
|
||||
assert router.topic == topic
|
||||
assert router.topic.model_type == topic.model_type
|
||||
|
||||
send, recv = channel[T]()
|
||||
router.senders.add(cast(Sender[CamelCaseModel], send))
|
||||
|
||||
return recv
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
if not self._tg:
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
|
||||
async def _networking_publish(self):
|
||||
# This with/for pattern ensures this method doesn't return until after the receiver closes
|
||||
# This is good for safety, but is mostly a redundant check.
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
try:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
except NoPeersSubscribedToTopicError:
|
||||
logger.trace(f"Failed to send over {topic} - No peers found.")
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR,
|
||||
) -> Keypair:
|
||||
"""
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
# operate with cross-process lock to avoid race conditions
|
||||
with FileLock(lock_path(path)):
|
||||
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
return keypair
|
||||
141
src/exo/routing/tests/test_event_buffer.py
Normal file
141
src/exo/routing/tests/test_event_buffer.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.events import Event, TestEvent
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
|
||||
def make_indexed_event(idx: int) -> tuple[int, Event]:
|
||||
"""Factory function to create a unique ForwarderEvent for a given index."""
|
||||
return (idx, TestEvent())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def buffer() -> OrderedBuffer[Event]:
|
||||
"""Provides a clean instance of OrderedBuffer[Event] for each test."""
|
||||
return OrderedBuffer[Event]()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_state(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that a new buffer is empty and starts at index 1."""
|
||||
assert buffer.next_idx_to_release == 0
|
||||
assert not buffer.store
|
||||
assert buffer.drain() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
|
||||
"""Tests ingesting and draining a simple, ordered sequence of events."""
|
||||
events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)]
|
||||
[buffer.ingest(*ev) for ev in events]
|
||||
|
||||
drained_events = buffer.drain_indexed()
|
||||
assert drained_events == events
|
||||
assert buffer.next_idx_to_release == 3
|
||||
assert not buffer.store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that out-of-order events are buffered and drained in the correct sequence."""
|
||||
event1 = make_indexed_event(0)
|
||||
event2 = make_indexed_event(1)
|
||||
event3 = make_indexed_event(2)
|
||||
|
||||
buffer.ingest(*event3)
|
||||
buffer.ingest(*event1)
|
||||
buffer.ingest(*event2)
|
||||
|
||||
drained_events = buffer.drain_indexed()
|
||||
assert drained_events == [event1, event2, event3]
|
||||
assert buffer.next_idx_to_release == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that draining stops when there is a gap in the event indices."""
|
||||
event1 = make_indexed_event(0)
|
||||
event3 = make_indexed_event(2)
|
||||
|
||||
buffer.ingest(*event1)
|
||||
buffer.ingest(*event3)
|
||||
|
||||
drained_events = buffer.drain_indexed()
|
||||
assert drained_events == [event1]
|
||||
assert buffer.next_idx_to_release == 1
|
||||
|
||||
assert buffer.drain() == []
|
||||
assert 2 in buffer.store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that once a gap is filled, the rest of the sequence is drained."""
|
||||
event0 = make_indexed_event(0)
|
||||
event2 = make_indexed_event(2)
|
||||
buffer.ingest(*event0)
|
||||
buffer.ingest(*event2)
|
||||
|
||||
buffer.drain()
|
||||
assert buffer.next_idx_to_release == 1
|
||||
|
||||
event1 = make_indexed_event(1)
|
||||
buffer.ingest(*event1)
|
||||
|
||||
drained_events = buffer.drain_indexed()
|
||||
assert [e[0] for e in drained_events] == [1, 2]
|
||||
assert buffer.next_idx_to_release == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that if multiple events for the same index are ingested, the first one wins."""
|
||||
event2_first = make_indexed_event(1)
|
||||
event2_second = (1, TestEvent())
|
||||
|
||||
buffer.ingest(*make_indexed_event(0))
|
||||
buffer.ingest(*event2_first)
|
||||
buffer.ingest(*event2_second) # This duplicate should be ignored
|
||||
|
||||
drained = buffer.drain_indexed()
|
||||
assert len(drained) == 2
|
||||
|
||||
assert drained[1][1].event_id == event2_first[1].event_id
|
||||
assert drained[1][1].event_id != event2_second[1].event_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
|
||||
"""Tests that events with an index lower than next_idx_to_release are dropped."""
|
||||
buffer.ingest(*make_indexed_event(0))
|
||||
buffer.ingest(*make_indexed_event(1))
|
||||
buffer.drain()
|
||||
|
||||
assert buffer.next_idx_to_release == 2
|
||||
|
||||
stale_event1 = make_indexed_event(0)
|
||||
stale_event2 = make_indexed_event(1)
|
||||
buffer.ingest(*stale_event1)
|
||||
buffer.ingest(*stale_event2)
|
||||
|
||||
assert not buffer.store
|
||||
assert buffer.drain() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
|
||||
"""Tests reusing the buffer after it has been fully drained."""
|
||||
buffer.ingest(*make_indexed_event(0))
|
||||
buffer.ingest(*make_indexed_event(1))
|
||||
buffer.drain()
|
||||
|
||||
assert buffer.next_idx_to_release == 2
|
||||
assert not buffer.store
|
||||
|
||||
buffer.ingest(*make_indexed_event(4))
|
||||
buffer.ingest(*make_indexed_event(2))
|
||||
|
||||
drained = buffer.drain_indexed()
|
||||
assert [e[0] for e in drained] == [2]
|
||||
assert buffer.next_idx_to_release == 3
|
||||
assert 4 in buffer.store
|
||||
47
src/exo/routing/topics.py
Normal file
47
src/exo/routing/topics.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.types.commands import ForwarderCommand
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class PublishPolicy(str, Enum):
|
||||
Never = "Never"
|
||||
"""Never publish to the network - this is a local message"""
|
||||
Minimal = "Minimal"
|
||||
"""Only publish when there is no local receiver for this type of message"""
|
||||
Always = "Always"
|
||||
"""Always publish to the network"""
|
||||
|
||||
|
||||
@dataclass # (frozen=True)
|
||||
class TypedTopic[T: CamelCaseModel]:
|
||||
topic: str
|
||||
publish_policy: PublishPolicy
|
||||
|
||||
model_type: type[
|
||||
T
|
||||
] # This can be worked around with evil type hacking, see https://stackoverflow.com/a/71720366 - I don't think it's necessary here.
|
||||
|
||||
@staticmethod
|
||||
def serialize(t: T) -> bytes:
|
||||
return t.model_dump_json().encode("utf-8")
|
||||
|
||||
def deserialize(self, b: bytes) -> T:
|
||||
return self.model_type.model_validate_json(b.decode("utf-8"))
|
||||
|
||||
|
||||
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
|
||||
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
|
||||
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
|
||||
ELECTION_MESSAGES = TypedTopic(
|
||||
"election_messages", PublishPolicy.Always, ElectionMessage
|
||||
)
|
||||
CONNECTION_MESSAGES = TypedTopic(
|
||||
"connection_messages", PublishPolicy.Never, ConnectionMessage
|
||||
)
|
||||
@@ -1,18 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from functools import singledispatch
|
||||
from typing import Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
EventFromEventLog,
|
||||
IndexedEvent,
|
||||
InstanceActivated,
|
||||
InstanceCreated,
|
||||
InstanceDeactivated,
|
||||
InstanceDeleted,
|
||||
InstanceReplacedAtomically,
|
||||
NodePerformanceMeasured,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
@@ -20,48 +19,74 @@ from exo.shared.types.events import (
|
||||
TaskDeleted,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
TestEvent,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
TopologyEdgeReplacedAtomically,
|
||||
TopologyNodeCreated,
|
||||
WorkerStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import Connection, Node
|
||||
from exo.shared.types.worker.common import NodeStatus, RunnerId
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.common import RunnerId, WorkerStatus
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceStatus
|
||||
from exo.shared.types.worker.runners import RunnerStatus
|
||||
|
||||
|
||||
@singledispatch
|
||||
def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to *state*.
|
||||
|
||||
Events decorated with ``@no_op_event`` set ``__no_apply__ = True`` on the
|
||||
class. Such events are considered *no-ops* and therefore leave the state
|
||||
unchanged without requiring a dedicated handler in this dispatch table.
|
||||
"""
|
||||
|
||||
if getattr(event, "__no_apply__", False):
|
||||
return state
|
||||
|
||||
raise RuntimeError(f"no handler registered for event type {type(event).__name__}")
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case TestEvent() | ChunkGenerated():
|
||||
return state
|
||||
case InstanceActivated():
|
||||
return apply_instance_activated(event, state)
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeactivated():
|
||||
return apply_instance_deactivated(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case NodePerformanceMeasured():
|
||||
return apply_node_performance_measured(event, state)
|
||||
case RunnerDeleted():
|
||||
return apply_runner_deleted(event, state)
|
||||
case RunnerStatusUpdated():
|
||||
return apply_runner_status_updated(event, state)
|
||||
case TaskCreated():
|
||||
return apply_task_created(event, state)
|
||||
case TaskDeleted():
|
||||
return apply_task_deleted(event, state)
|
||||
case TaskFailed():
|
||||
return apply_task_failed(event, state)
|
||||
case TaskStateUpdated():
|
||||
return apply_task_state_updated(event, state)
|
||||
case WorkerStatusUpdated():
|
||||
return apply_worker_status_updated(event, state)
|
||||
case TopologyNodeCreated():
|
||||
return apply_topology_node_created(event, state)
|
||||
case TopologyEdgeCreated():
|
||||
return apply_topology_edge_created(event, state)
|
||||
case TopologyEdgeDeleted():
|
||||
return apply_topology_edge_deleted(event, state)
|
||||
|
||||
|
||||
def apply(state: State, event: EventFromEventLog[Event]) -> State:
|
||||
def apply(state: State, event: IndexedEvent) -> State:
|
||||
# Just to test that events are only applied in correct order
|
||||
if state.last_event_applied_idx != event.idx - 1:
|
||||
logger.warning(
|
||||
f"Expected event {state.last_event_applied_idx + 1} but received {event.idx}"
|
||||
)
|
||||
assert state.last_event_applied_idx == event.idx - 1
|
||||
new_state: State = event_apply(event.event, state)
|
||||
return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})
|
||||
return new_state.model_copy(update={"last_event_applied_idx": event.idx})
|
||||
|
||||
|
||||
@event_apply.register(TaskCreated)
|
||||
def apply_task_created(event: TaskCreated, state: State) -> State:
|
||||
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
@event_apply.register(TaskDeleted)
|
||||
def apply_task_deleted(event: TaskDeleted, state: State) -> State:
|
||||
new_tasks: Mapping[TaskId, Task] = {
|
||||
tid: task for tid, task in state.tasks.items() if tid != event.task_id
|
||||
@@ -69,7 +94,6 @@ def apply_task_deleted(event: TaskDeleted, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
@event_apply.register(TaskStateUpdated)
|
||||
def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
|
||||
if event.task_id not in state.tasks:
|
||||
return state
|
||||
@@ -86,7 +110,6 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
@event_apply.register(TaskFailed)
|
||||
def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
if event.task_id not in state.tasks:
|
||||
return state
|
||||
@@ -98,7 +121,6 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
@event_apply.register(InstanceCreated)
|
||||
def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
instance = event.instance
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
@@ -108,13 +130,12 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
@event_apply.register(InstanceActivated)
|
||||
def apply_instance_activated(event: InstanceActivated, state: State) -> State:
|
||||
if event.instance_id not in state.instances:
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(
|
||||
update={"type": InstanceStatus.ACTIVE}
|
||||
update={"instance_type": InstanceStatus.ACTIVE}
|
||||
)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
@@ -123,13 +144,12 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
@event_apply.register(InstanceDeactivated)
|
||||
def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> State:
|
||||
if event.instance_id not in state.instances:
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(
|
||||
update={"type": InstanceStatus.INACTIVE}
|
||||
update={"instance_type": InstanceStatus.INACTIVE}
|
||||
)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
@@ -138,7 +158,6 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
@event_apply.register(InstanceDeleted)
|
||||
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
|
||||
@@ -146,19 +165,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
@event_apply.register(InstanceReplacedAtomically)
|
||||
def apply_instance_replaced_atomically(
|
||||
event: InstanceReplacedAtomically, state: State
|
||||
) -> State:
|
||||
new_instances = dict(state.instances)
|
||||
if event.instance_to_replace in new_instances:
|
||||
del new_instances[event.instance_to_replace]
|
||||
if event.new_instance_id in state.instances:
|
||||
new_instances[event.new_instance_id] = state.instances[event.new_instance_id]
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
|
||||
|
||||
@event_apply.register(RunnerStatusUpdated)
|
||||
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
**state.runners,
|
||||
@@ -167,7 +173,6 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta
|
||||
return state.model_copy(update={"runners": new_runners})
|
||||
|
||||
|
||||
@event_apply.register(RunnerDeleted)
|
||||
def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
rid: rs for rid, rs in state.runners.items() if rid != event.runner_id
|
||||
@@ -175,7 +180,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
return state.model_copy(update={"runners": new_runners})
|
||||
|
||||
|
||||
@event_apply.register(NodePerformanceMeasured)
|
||||
# TODO: This whole function needs fixing
|
||||
def apply_node_performance_measured(
|
||||
event: NodePerformanceMeasured, state: State
|
||||
) -> State:
|
||||
@@ -187,58 +192,39 @@ def apply_node_performance_measured(
|
||||
topology = copy.copy(state.topology)
|
||||
if not topology.contains_node(event.node_id):
|
||||
# TODO: figure out why this is happening in the first place
|
||||
topology.add_node(Node(node_id=event.node_id))
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
topology.update_node_profile(event.node_id, event.node_profile)
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
@event_apply.register(WorkerStatusUpdated)
|
||||
def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> State:
|
||||
new_node_status: Mapping[NodeId, NodeStatus] = {
|
||||
new_node_status: Mapping[NodeId, WorkerStatus] = {
|
||||
**state.node_status,
|
||||
event.node_id: event.node_state,
|
||||
}
|
||||
return state.model_copy(update={"node_status": new_node_status})
|
||||
|
||||
|
||||
@event_apply.register(TopologyNodeCreated)
|
||||
def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_node(Node(node_id=event.node_id))
|
||||
if event.role == "MASTER":
|
||||
topology.set_master_node_id(event.node_id)
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
@event_apply.register(TopologyEdgeCreated)
|
||||
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_connection(event.edge)
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
@event_apply.register(TopologyEdgeReplacedAtomically)
|
||||
def apply_topology_edge_replaced_atomically(
|
||||
event: TopologyEdgeReplacedAtomically, state: State
|
||||
) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
topology.update_connection_profile(event.edge)
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
@event_apply.register(TopologyEdgeDeleted)
|
||||
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
if not topology.contains_connection(event.edge):
|
||||
return state
|
||||
topology.remove_connection(event.edge)
|
||||
opposite_edge = Connection(
|
||||
local_node_id=event.edge.send_back_node_id,
|
||||
send_back_node_id=event.edge.local_node_id,
|
||||
local_multiaddr=event.edge.send_back_multiaddr,
|
||||
send_back_multiaddr=event.edge.local_multiaddr,
|
||||
)
|
||||
if not topology.contains_connection(opposite_edge):
|
||||
return state.model_copy(update={"topology": topology})
|
||||
topology.remove_connection(opposite_edge)
|
||||
if not topology.contains_connection(event.edge) and topology.contains_connection(
|
||||
event.edge.reverse()
|
||||
):
|
||||
topology.remove_connection(event.edge.reverse())
|
||||
# TODO: Clean up removing the reverse connection
|
||||
return state.model_copy(update={"topology": topology})
|
||||
@@ -1,3 +0,0 @@
|
||||
from .apply import apply
|
||||
|
||||
__all__ = ["apply"]
|
||||
@@ -21,8 +21,10 @@ EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
|
||||
EXO_IPC_DIR = EXO_HOME / "ipc"
|
||||
|
||||
# libp2p topics for event forwarding
|
||||
LIBP2P_WORKER_EVENTS_TOPIC = "worker_events"
|
||||
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
|
||||
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
|
||||
LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
|
||||
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
|
||||
LB_TFLOPS = 2.3
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Database implementations for event storage."""
|
||||
|
||||
from .sqlite import AsyncSQLiteEventStorage, EventStorageProtocol
|
||||
|
||||
__all__ = ["AsyncSQLiteEventStorage", "EventStorageProtocol"]
|
||||
|
||||
19
src/exo/shared/db/config.py
Normal file
19
src/exo/shared/db/config.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.constants import EXO_GLOBAL_EVENT_DB
|
||||
|
||||
|
||||
class EventLogConfig(BaseModel):
|
||||
"""Configuration for the event log system"""
|
||||
|
||||
# Batch processing settings
|
||||
batch_size: int = 100
|
||||
batch_timeout_ms: int = 100
|
||||
debounce_ms: int = 10
|
||||
max_age_ms: int = 100
|
||||
|
||||
def get_db_path(self) -> Path:
|
||||
"""Get the full path for a specific event log type"""
|
||||
return EXO_GLOBAL_EVENT_DB
|
||||
@@ -8,13 +8,13 @@ from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, create_async_engine
|
||||
|
||||
from exo.shared.types.events import Event, EventParser, NodeId
|
||||
from exo.shared.types.events._events import Heartbeat
|
||||
from exo.shared.types.events.components import EventFromEventLog
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, IndexedEvent, event_tag
|
||||
|
||||
from .types import StoredEvent
|
||||
|
||||
@@ -73,9 +73,7 @@ class AsyncSQLiteEventStorage:
|
||||
for event in events:
|
||||
await self._write_queue.put((event, origin))
|
||||
|
||||
async def get_events_since(
|
||||
self, last_idx: int, ignore_no_op_events: bool = False
|
||||
) -> Sequence[EventFromEventLog[Event]]:
|
||||
async def get_events_since(self, last_idx: int) -> Sequence[IndexedEvent]:
|
||||
"""Retrieve events after a specific index."""
|
||||
if self._closed:
|
||||
raise RuntimeError("Storage is closed")
|
||||
@@ -92,10 +90,10 @@ class AsyncSQLiteEventStorage:
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
events: list[EventFromEventLog[Event]] = []
|
||||
events: list[IndexedEvent] = []
|
||||
for row in rows:
|
||||
rowid: int = cast(int, row[0])
|
||||
origin: str = cast(str, row[1])
|
||||
# origin: str = cast(str, row[1])
|
||||
# Parse JSON string to dict
|
||||
raw_event_data = row[2] # type: ignore[reportAny] - SQLAlchemy result is Any
|
||||
if isinstance(raw_event_data, str):
|
||||
@@ -104,14 +102,12 @@ class AsyncSQLiteEventStorage:
|
||||
)
|
||||
else:
|
||||
event_data = cast(dict[str, Any], raw_event_data)
|
||||
event = EventParser.validate_python(event_data)
|
||||
if ignore_no_op_events and event.__no_apply__:
|
||||
continue
|
||||
event: Event = TypeAdapter(Event).validate_python(event_data) # type: ignore
|
||||
events.append(
|
||||
EventFromEventLog(
|
||||
event=event,
|
||||
origin=NodeId(origin),
|
||||
idx_in_log=rowid, # rowid becomes idx_in_log
|
||||
IndexedEvent(
|
||||
event=event, # type: ignore
|
||||
# origin=NodeId(origin),
|
||||
idx=rowid, # rowid becomes idx_in_log
|
||||
)
|
||||
)
|
||||
|
||||
@@ -325,7 +321,7 @@ class AsyncSQLiteEventStorage:
|
||||
for event, origin in batch:
|
||||
stored_event = StoredEvent(
|
||||
origin=origin,
|
||||
event_type=event.event_type,
|
||||
event_type=event_tag(event),
|
||||
event_id=str(event.event_id),
|
||||
event_data=event.model_dump(
|
||||
mode="json"
|
||||
@@ -334,8 +330,7 @@ class AsyncSQLiteEventStorage:
|
||||
session.add(stored_event)
|
||||
|
||||
await session.commit()
|
||||
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
|
||||
logger.debug(f"Committed batch of {len(batch)} events")
|
||||
logger.debug(f"Committed batch of {len(batch)} events")
|
||||
|
||||
except OperationalError as e:
|
||||
if "database is locked" in str(e):
|
||||
@@ -393,7 +388,7 @@ class AsyncSQLiteEventStorage:
|
||||
for event, origin in batch:
|
||||
stored_event = StoredEvent(
|
||||
origin=origin,
|
||||
event_type=event.event_type,
|
||||
event_type=event_tag(event),
|
||||
event_id=str(event.event_id),
|
||||
event_data=event.model_dump(mode="json"),
|
||||
)
|
||||
@@ -401,10 +396,9 @@ class AsyncSQLiteEventStorage:
|
||||
|
||||
await session.commit()
|
||||
|
||||
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
|
||||
logger.debug(
|
||||
f"Committed batch of {len(batch)} events after {retry_count} retries"
|
||||
)
|
||||
logger.debug(
|
||||
f"Committed batch of {len(batch)} events after {retry_count} retries"
|
||||
)
|
||||
return
|
||||
|
||||
except OperationalError as e:
|
||||
110
src/exo/shared/db/event_log_manager.py
Normal file
110
src/exo/shared/db/event_log_manager.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from exo.shared.constants import EXO_HOME
|
||||
from exo.shared.db.config import EventLogConfig
|
||||
from exo.shared.db.connector import AsyncSQLiteEventStorage
|
||||
from exo.utils.fs import ensure_directory_exists
|
||||
|
||||
|
||||
class EventLogManager:
|
||||
"""
|
||||
Manages both worker and global event log connectors.
|
||||
Used by both master and worker processes with different access patterns:
|
||||
|
||||
- Worker: writes to worker_events, tails global_events
|
||||
- Master (elected): writes to global_events, tails global_events
|
||||
- Master (replica): writes to worker_events, tails global_events
|
||||
"""
|
||||
|
||||
def __init__(self, config: EventLogConfig):
|
||||
self._config = config
|
||||
self._connector: AsyncSQLiteEventStorage | None = None
|
||||
|
||||
# Ensure base directory exists
|
||||
ensure_directory_exists(EXO_HOME)
|
||||
|
||||
# TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
|
||||
async def initialize(self, max_retries: int = 3) -> None:
|
||||
"""Initialize both connectors with retry logic - call this during startup"""
|
||||
# Both master and worker need both connectors
|
||||
retry_count: int = 0
|
||||
last_error: Exception | None = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self.get_connector()
|
||||
break
|
||||
except OperationalError as e:
|
||||
last_error = e
|
||||
if "database is locked" in str(e) and retry_count < max_retries - 1:
|
||||
retry_count += 1
|
||||
delay = cast(float, 0.5 * (2**retry_count))
|
||||
logger.warning(
|
||||
f"Database locked while initializing db, retry {retry_count}/{max_retries} after {delay}s"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.opt(exception=e).error(
|
||||
f"Failed to initialize db after {retry_count + 1} attempts"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Could not initialize db after {retry_count + 1} attempts"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Unexpected error initializing db")
|
||||
raise
|
||||
|
||||
if retry_count >= max_retries and last_error:
|
||||
raise RuntimeError(
|
||||
f"Could not initialize db after {max_retries} attempts"
|
||||
) from last_error
|
||||
logger.bind(user_facing=True).info("Initialized all event log connectors")
|
||||
|
||||
async def get_connector(self) -> AsyncSQLiteEventStorage:
|
||||
"""Get or create a connector for the specified log type"""
|
||||
if not self._connector:
|
||||
db_path = self._config.get_db_path()
|
||||
|
||||
try:
|
||||
connector = AsyncSQLiteEventStorage(
|
||||
db_path=db_path,
|
||||
batch_size=self._config.batch_size,
|
||||
batch_timeout_ms=self._config.batch_timeout_ms,
|
||||
debounce_ms=self._config.debounce_ms,
|
||||
max_age_ms=self._config.max_age_ms,
|
||||
)
|
||||
|
||||
# Start the connector (creates tables if needed)
|
||||
await connector.start()
|
||||
|
||||
self._connector = connector
|
||||
logger.bind(user_facing=True).info(
|
||||
f"Initialized db connector at {db_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.bind(user_facing=True).opt(exception=e).error(
|
||||
"Failed to create db connector"
|
||||
)
|
||||
raise
|
||||
|
||||
return self._connector
|
||||
|
||||
@property
|
||||
def events(self) -> AsyncSQLiteEventStorage:
|
||||
"""Access event log (must call initialize() first)"""
|
||||
if not self._connector:
|
||||
raise RuntimeError(
|
||||
"Event log manager not initialized. Call initialize() first."
|
||||
)
|
||||
return self._connector
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all open connectors"""
|
||||
assert self._connector is not None
|
||||
await self._connector.close()
|
||||
logger.bind(user_facing=True).info("Closed db connector")
|
||||
self._connector = None
|
||||
@@ -1,15 +0,0 @@
|
||||
"""SQLite event storage implementation."""
|
||||
|
||||
from .config import EventLogConfig, EventLogType
|
||||
from .connector import AsyncSQLiteEventStorage
|
||||
from .event_log_manager import EventLogManager
|
||||
from .types import EventStorageProtocol, StoredEvent
|
||||
|
||||
__all__ = [
|
||||
"AsyncSQLiteEventStorage",
|
||||
"EventLogConfig",
|
||||
"EventLogManager",
|
||||
"EventLogType",
|
||||
"EventStorageProtocol",
|
||||
"StoredEvent",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.constants import EXO_GLOBAL_EVENT_DB, EXO_WORKER_EVENT_DB
|
||||
|
||||
|
||||
class EventLogType(str, Enum):
|
||||
"""Types of event logs in the system"""
|
||||
|
||||
WORKER_EVENTS = "worker_events"
|
||||
GLOBAL_EVENTS = "global_events"
|
||||
|
||||
|
||||
class EventLogConfig(BaseModel):
|
||||
"""Configuration for the event log system"""
|
||||
|
||||
# Batch processing settings
|
||||
batch_size: int = 100
|
||||
batch_timeout_ms: int = 100
|
||||
debounce_ms: int = 10
|
||||
max_age_ms: int = 100
|
||||
|
||||
def get_db_path(self, log_type: EventLogType) -> Path:
|
||||
"""Get the full path for a specific event log type"""
|
||||
if log_type == EventLogType.WORKER_EVENTS:
|
||||
return EXO_WORKER_EVENT_DB
|
||||
elif log_type == EventLogType.GLOBAL_EVENTS:
|
||||
return EXO_GLOBAL_EVENT_DB
|
||||
else:
|
||||
raise ValueError(f"Unknown log type: {log_type}")
|
||||
@@ -1,122 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from exo.shared.constants import EXO_HOME
|
||||
from exo.shared.db.sqlite.config import EventLogConfig, EventLogType
|
||||
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from exo.shared.utils.fs import ensure_directory_exists
|
||||
|
||||
|
||||
class EventLogManager:
|
||||
"""
|
||||
Manages both worker and global event log connectors.
|
||||
Used by both master and worker processes with different access patterns:
|
||||
|
||||
- Worker: writes to worker_events, tails global_events
|
||||
- Master (elected): writes to global_events, tails global_events
|
||||
- Master (replica): writes to worker_events, tails global_events
|
||||
"""
|
||||
|
||||
def __init__(self, config: EventLogConfig):
|
||||
self._config = config
|
||||
self._connectors: Dict[EventLogType, AsyncSQLiteEventStorage] = {}
|
||||
|
||||
# Ensure base directory exists
|
||||
ensure_directory_exists(EXO_HOME)
|
||||
|
||||
# TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
|
||||
async def initialize(self, max_retries: int = 3) -> None:
|
||||
"""Initialize both connectors with retry logic - call this during startup"""
|
||||
# Both master and worker need both connectors
|
||||
for log_type in [EventLogType.WORKER_EVENTS, EventLogType.GLOBAL_EVENTS]:
|
||||
retry_count: int = 0
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self.get_connector(log_type)
|
||||
break
|
||||
except OperationalError as e:
|
||||
last_error = e
|
||||
if "database is locked" in str(e) and retry_count < max_retries - 1:
|
||||
retry_count += 1
|
||||
delay = cast(float, 0.5 * (2**retry_count))
|
||||
logger.warning(
|
||||
f"Database locked while initializing {log_type.value}, retry {retry_count}/{max_retries} after {delay}s"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.opt(exception=e).error(
|
||||
f"Failed to initialize {log_type.value} after {retry_count + 1} attempts"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Could not initialize {log_type.value} database after {retry_count + 1} attempts"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error(
|
||||
f"Unexpected error initializing {log_type.value}"
|
||||
)
|
||||
raise
|
||||
|
||||
if retry_count >= max_retries and last_error:
|
||||
raise RuntimeError(
|
||||
f"Could not initialize {log_type.value} database after {max_retries} attempts"
|
||||
) from last_error
|
||||
logger.bind(user_facing=True).info("Initialized all event log connectors")
|
||||
|
||||
async def get_connector(self, log_type: EventLogType) -> AsyncSQLiteEventStorage:
|
||||
"""Get or create a connector for the specified log type"""
|
||||
if log_type not in self._connectors:
|
||||
db_path = self._config.get_db_path(log_type)
|
||||
|
||||
try:
|
||||
connector = AsyncSQLiteEventStorage(
|
||||
db_path=db_path,
|
||||
batch_size=self._config.batch_size,
|
||||
batch_timeout_ms=self._config.batch_timeout_ms,
|
||||
debounce_ms=self._config.debounce_ms,
|
||||
max_age_ms=self._config.max_age_ms,
|
||||
)
|
||||
|
||||
# Start the connector (creates tables if needed)
|
||||
await connector.start()
|
||||
|
||||
self._connectors[log_type] = connector
|
||||
logger.bind(user_facing=True).info(
|
||||
f"Initialized {log_type.value} connector at {db_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.bind(user_facing=True).opt(exception=e).error(
|
||||
f"Failed to create {log_type.value} connector"
|
||||
)
|
||||
raise
|
||||
|
||||
return self._connectors[log_type]
|
||||
|
||||
@property
|
||||
def worker_events(self) -> AsyncSQLiteEventStorage:
|
||||
"""Access worker events log (must call initialize() first)"""
|
||||
if EventLogType.WORKER_EVENTS not in self._connectors:
|
||||
raise RuntimeError(
|
||||
"Event log manager not initialized. Call initialize() first."
|
||||
)
|
||||
return self._connectors[EventLogType.WORKER_EVENTS]
|
||||
|
||||
@property
|
||||
def global_events(self) -> AsyncSQLiteEventStorage:
|
||||
"""Access global events log (must call initialize() first)"""
|
||||
if EventLogType.GLOBAL_EVENTS not in self._connectors:
|
||||
raise RuntimeError(
|
||||
"Event log manager not initialized. Call initialize() first."
|
||||
)
|
||||
return self._connectors[EventLogType.GLOBAL_EVENTS]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all open connectors"""
|
||||
for log_type, connector in self._connectors.items():
|
||||
await connector.close()
|
||||
logger.bind(user_facing=True).info(f"Closed {log_type.value} connector")
|
||||
self._connectors.clear()
|
||||
@@ -1,13 +1,9 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Protocol, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, Index
|
||||
from sqlmodel import JSON, Column, Field, SQLModel
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events.components import EventFromEventLog
|
||||
|
||||
|
||||
class StoredEvent(SQLModel, table=True):
|
||||
"""SQLite representation of an event in the event log.
|
||||
@@ -29,28 +25,3 @@ class StoredEvent(SQLModel, table=True):
|
||||
)
|
||||
|
||||
__table_args__ = (Index("idx_events_origin_created", "origin", "created_at"),)
|
||||
|
||||
|
||||
class EventStorageProtocol(Protocol):
|
||||
"""Protocol for event storage implementations."""
|
||||
|
||||
async def append_events(self, events: Sequence[Event], origin: NodeId) -> None:
|
||||
"""Append events to the log (fire-and-forget).
|
||||
|
||||
Events are queued for batched writing and assigned idx_in_log
|
||||
when committed to storage.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_events_since(
|
||||
self, last_idx: int
|
||||
) -> Sequence[EventFromEventLog[Event]]:
|
||||
"""Retrieve events after a specific index.
|
||||
|
||||
Returns events in idx_in_log order.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the storage connection and cleanup resources."""
|
||||
...
|
||||
183
src/exo/shared/election.py
Normal file
183
src/exo/shared/election.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio import (
|
||||
CancelScope,
|
||||
Event,
|
||||
create_task_group,
|
||||
get_cancelled_exc_class,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
ELECTION_TIMEOUT = 3.0
|
||||
|
||||
|
||||
class ElectionMessage(CamelCaseModel):
|
||||
clock: int
|
||||
seniority: int
|
||||
node_id: NodeId
|
||||
|
||||
# Could eventually include a list of neighbour nodes for centrality
|
||||
def __lt__(self, other: Self):
|
||||
if self.seniority != other.seniority:
|
||||
return self.seniority < other.seniority
|
||||
else:
|
||||
return self.node_id < other.node_id
|
||||
|
||||
|
||||
class ElectionResult(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
is_new_master: bool
|
||||
historic_messages: list[ConnectionMessage]
|
||||
|
||||
|
||||
class Election:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
election_message_receiver: Receiver[ElectionMessage],
|
||||
election_message_sender: Sender[ElectionMessage],
|
||||
election_result_sender: Sender[ElectionResult],
|
||||
connection_message_receiver: Receiver[ConnectionMessage],
|
||||
*,
|
||||
is_candidate: bool = True,
|
||||
seniority: int = 0,
|
||||
):
|
||||
# If we aren't a candidate, simply don't increment seniority.
|
||||
# For reference: This node can be elected master if all nodes are not master candidates
|
||||
# Any master candidate will automatically win out over this node.
|
||||
self.seniority = seniority if is_candidate else -1
|
||||
self.clock = 0
|
||||
self.node_id = node_id
|
||||
# Every node spawns as master
|
||||
self.master_node_id: NodeId = node_id
|
||||
|
||||
self._em_sender = election_message_sender
|
||||
self._em_receiver = election_message_receiver
|
||||
self._er_sender = election_result_sender
|
||||
self._cm_receiver = connection_message_receiver
|
||||
|
||||
# Campaign state
|
||||
self._candidates: list[ElectionMessage] = []
|
||||
self._campaign_cancel_scope: CancelScope | None = None
|
||||
self._campaign_done: Event | None = None
|
||||
self._tg: TaskGroup | None = None
|
||||
self._connection_messages: list[ConnectionMessage] = []
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
await self._campaign(None)
|
||||
|
||||
if self._campaign_cancel_scope is not None:
|
||||
self._campaign_cancel_scope.cancel()
|
||||
# Only exit once the latest campaign has finished
|
||||
if self._campaign_done is not None:
|
||||
await self._campaign_done.wait()
|
||||
|
||||
async def elect(self, node_id: NodeId) -> None:
|
||||
is_new_master = node_id != self.master_node_id
|
||||
self.master_node_id = node_id
|
||||
await self._er_sender.send(
|
||||
ElectionResult(
|
||||
node_id=node_id,
|
||||
is_new_master=is_new_master,
|
||||
historic_messages=self._connection_messages,
|
||||
)
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._tg:
|
||||
logger.warning(
|
||||
"Attempted to shutdown election service that was not running"
|
||||
)
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _election_receiver(self) -> None:
|
||||
with self._em_receiver as election_messages:
|
||||
async for message in election_messages:
|
||||
if message.node_id == self.node_id:
|
||||
# Drop messages from us (See exo.routing.router)
|
||||
continue
|
||||
# If a new round is starting, we participate
|
||||
if message.clock > self.clock:
|
||||
self.clock = message.clock
|
||||
await self._campaign(message)
|
||||
continue
|
||||
# Dismiss old messages
|
||||
if message.clock < self.clock:
|
||||
continue
|
||||
logger.debug(f"Election added candidate {message}")
|
||||
# Now we are processing this rounds messages - including the message that triggered this round.
|
||||
self._candidates.append(message)
|
||||
|
||||
async def _connection_receiver(self) -> None:
|
||||
with self._cm_receiver as connection_messages:
|
||||
async for msg in connection_messages:
|
||||
# These messages are strictly peer to peer
|
||||
self.clock += 1
|
||||
await self._campaign(None)
|
||||
self._connection_messages.append(msg)
|
||||
|
||||
async def _campaign(self, initial_message: ElectionMessage | None) -> None:
|
||||
# Kill the old campaign
|
||||
if self._campaign_cancel_scope:
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done:
|
||||
await self._campaign_done.wait()
|
||||
|
||||
candidates: list[ElectionMessage] = []
|
||||
if initial_message:
|
||||
candidates.append(initial_message)
|
||||
self._candidates = candidates
|
||||
done = Event()
|
||||
self._campaign_done = done
|
||||
|
||||
assert self._tg is not None, (
|
||||
"Election campaign started before election service initialized"
|
||||
)
|
||||
# Spin off a new campaign
|
||||
self._tg.start_soon(self._complete_campaign, self.clock, candidates, done)
|
||||
|
||||
async def _complete_campaign(
|
||||
self, clock: int, candidates: list[ElectionMessage], done: Event
|
||||
) -> None:
|
||||
scope = CancelScope()
|
||||
try:
|
||||
with scope:
|
||||
self._campaign_cancel_scope = scope
|
||||
logger.info(f"Election {clock} started")
|
||||
|
||||
candidates.append(self._election_status(clock))
|
||||
await self._em_sender.send(self._election_status(clock))
|
||||
|
||||
await anyio.sleep(ELECTION_TIMEOUT)
|
||||
|
||||
# Election finished!
|
||||
candidates = sorted(candidates)
|
||||
logger.debug(f"Election queue {candidates}")
|
||||
elected = candidates[-1]
|
||||
logger.info("Election finished")
|
||||
if self.node_id == elected.node_id and self.seniority >= 0:
|
||||
self.seniority = max(self.seniority, len(candidates))
|
||||
await self.elect(elected.node_id)
|
||||
except get_cancelled_exc_class():
|
||||
logger.info("Election cancelled")
|
||||
finally:
|
||||
if self._campaign_cancel_scope is scope:
|
||||
self._campaign_cancel_scope = None
|
||||
done.set()
|
||||
|
||||
def _election_status(self, clock: int | None = None) -> ElectionMessage:
|
||||
c = self.clock if clock is None else clock
|
||||
return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id)
|
||||
@@ -1,28 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
env_model_config = ConfigDict(
|
||||
strict=True,
|
||||
frozen=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(BaseModel):
|
||||
model_config = env_model_config
|
||||
|
||||
|
||||
EnvSchema = TypeVar("EnvSchema", bound=BaseEnv)
|
||||
|
||||
|
||||
def get_validated_env(
|
||||
environment_schema: type[EnvSchema], logger: logging.Logger
|
||||
) -> EnvSchema:
|
||||
try:
|
||||
return environment_schema.model_validate(os.environ, strict=True)
|
||||
except ValidationError as e:
|
||||
logger.error("Environment Variables Validation Failed: %s", e)
|
||||
raise e
|
||||
@@ -18,6 +18,7 @@ class AsyncConnection[SendT, RecvT]:
|
||||
- await send(...) from asyncio code
|
||||
- send_sync(...) from executor/background threads
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection):
|
||||
self._conn = conn
|
||||
self._send_lock = threading.Lock()
|
||||
@@ -44,7 +45,7 @@ class AsyncConnection[SendT, RecvT]:
|
||||
def _recv_blocking(self) -> RecvT:
|
||||
# Not strictly needed in your parent, but safe if misused elsewhere
|
||||
with self._recv_lock:
|
||||
return self._conn.recv() # type: ignore[no-any-return]
|
||||
return self._conn.recv() # type: ignore[no-any-return]
|
||||
|
||||
async def poll(self, timeout: float | None = None) -> bool:
|
||||
return await asyncio.to_thread(self._conn.poll, timeout)
|
||||
@@ -52,12 +53,15 @@ class AsyncConnection[SendT, RecvT]:
|
||||
def close(self) -> None:
|
||||
self._conn.close()
|
||||
|
||||
|
||||
_conn: Optional[AsyncConnection[RunnerResponse, RunnerMessage]] = None
|
||||
|
||||
|
||||
def set_conn(c: AsyncConnection[RunnerResponse, RunnerMessage]) -> None:
|
||||
global _conn
|
||||
_conn = c
|
||||
|
||||
|
||||
def get_conn() -> AsyncConnection[RunnerResponse, RunnerMessage]:
|
||||
if _conn is None:
|
||||
raise RuntimeError("Global conn has not been set yet")
|
||||
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exo.shared.utils.fs import StrPath, ensure_parent_directory_exists
|
||||
from exo.utils.fs import StrPath, ensure_parent_directory_exists
|
||||
|
||||
# open in read-write mode, creates file if it doesn't exist already,
|
||||
# closes this file descriptor in any children processes (prevents FD leaking),
|
||||
|
||||
@@ -33,7 +33,7 @@ from typing import Callable
|
||||
from cobs import cobs # pyright: ignore[reportMissingTypeStubs]
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from exo.shared.utils.fs import (
|
||||
from exo.utils.fs import (
|
||||
StrPath,
|
||||
delete_if_exists,
|
||||
ensure_parent_directory_exists,
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import final
|
||||
|
||||
import base58
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from filelock import FileLock
|
||||
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
|
||||
|
||||
@final
|
||||
class PeerId:
|
||||
"""
|
||||
A libp2p peer identifier derived from a cryptographic public key.
|
||||
Compatible with py-libp2p's PeerID interface.
|
||||
"""
|
||||
|
||||
def __init__(self, peer_id_bytes: bytes) -> None:
|
||||
self._bytes = peer_id_bytes
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> "PeerId":
|
||||
"""Create PeerId from raw bytes."""
|
||||
return PeerId(data)
|
||||
|
||||
@staticmethod
|
||||
def from_public_key(public_key_bytes: bytes) -> "PeerId":
|
||||
"""Create PeerId from a public key by hashing it."""
|
||||
# For Ed25519 keys, libp2p uses the identity hash (no hashing) for keys <= 42 bytes
|
||||
# Since Ed25519 public keys are 32 bytes, we use identity hash
|
||||
if len(public_key_bytes) <= 42:
|
||||
return PeerId(public_key_bytes)
|
||||
else:
|
||||
# For larger keys, use SHA-256
|
||||
hash_digest = hashlib.sha256(public_key_bytes).digest()
|
||||
return PeerId(hash_digest)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
"""Return the raw bytes of this PeerId."""
|
||||
return self._bytes
|
||||
|
||||
def to_base58(self) -> str:
|
||||
"""Return the base58-encoded string representation."""
|
||||
return base58.b58encode(self._bytes).decode("ascii")
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the base58-encoded string representation."""
|
||||
return self.to_base58()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return debug representation."""
|
||||
return f"PeerId('{self.to_base58()}')"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check equality with another PeerId."""
|
||||
if not isinstance(other, PeerId):
|
||||
return False
|
||||
return self._bytes == other._bytes
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Make PeerId hashable."""
|
||||
return hash(self._bytes)
|
||||
|
||||
|
||||
@final
|
||||
class Keypair:
|
||||
"""
|
||||
A py-libp2p compatible keypair implementation.
|
||||
Provides the same interface as py-libp2p's KeyPair.
|
||||
"""
|
||||
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
|
||||
@staticmethod
|
||||
def generate_ed25519() -> "Keypair":
|
||||
"""Generate a new Ed25519 keypair."""
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
return Keypair(private_key)
|
||||
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(data: bytes) -> "Keypair":
|
||||
"""
|
||||
Deserialize a keypair from libp2p protobuf encoding.
|
||||
Compatible with py-libp2p's serialization format.
|
||||
"""
|
||||
if len(data) < 2:
|
||||
raise ValueError("Invalid protobuf data: too short")
|
||||
|
||||
# Simple protobuf parsing for our specific use case
|
||||
# We expect: field 1 (type) as varint, field 2 (data) as bytes
|
||||
offset = 0
|
||||
|
||||
# Parse type field (field tag 1, wire type 0 = varint)
|
||||
if data[offset] != 0x08: # field 1, varint
|
||||
raise ValueError("Expected type field")
|
||||
offset += 1
|
||||
|
||||
key_type = data[offset]
|
||||
offset += 1
|
||||
|
||||
if key_type != 1: # Ed25519
|
||||
raise ValueError(f"Unsupported key type: {key_type}")
|
||||
|
||||
# Parse data field (field tag 2, wire type 2 = length-delimited)
|
||||
if offset >= len(data) or data[offset] != 0x12: # field 2, bytes
|
||||
raise ValueError("Expected data field")
|
||||
offset += 1
|
||||
|
||||
# Parse length
|
||||
data_length = data[offset]
|
||||
offset += 1
|
||||
|
||||
if data_length not in (32, 64):
|
||||
raise ValueError(f"Invalid Ed25519 private key length: {data_length}")
|
||||
|
||||
if offset + data_length > len(data):
|
||||
raise ValueError("Truncated private key data")
|
||||
|
||||
key_data = data[offset : offset + data_length]
|
||||
|
||||
try:
|
||||
if data_length == 64:
|
||||
# libp2p format: 32 bytes private key seed + 32 bytes public key
|
||||
private_key_seed = key_data[:32]
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(
|
||||
private_key_seed
|
||||
)
|
||||
else:
|
||||
# Raw 32-byte private key
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(key_data)
|
||||
|
||||
return Keypair(private_key)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid Ed25519 private key: {e}") from e
|
||||
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
"""
|
||||
Serialize this keypair to libp2p protobuf encoding.
|
||||
Compatible with py-libp2p's serialization format.
|
||||
"""
|
||||
private_key_bytes = self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
public_key_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
|
||||
# libp2p Ed25519 format: private key seed (32) + public key (32)
|
||||
combined_key_data = private_key_bytes + public_key_bytes
|
||||
|
||||
# Build protobuf manually for our simple case
|
||||
# Field 1 (type): tag=0x08, value=1 (Ed25519)
|
||||
# Field 2 (data): tag=0x12, length=64, data=combined_key_data
|
||||
result = bytearray()
|
||||
result.extend([0x08, 0x01]) # field 1: type = 1 (Ed25519)
|
||||
result.extend([0x12, 0x40]) # field 2: length = 64 bytes
|
||||
result.extend(combined_key_data)
|
||||
|
||||
return bytes(result)
|
||||
|
||||
def to_peer_id(self) -> PeerId:
|
||||
"""Generate a PeerId from this keypair's public key."""
|
||||
public_key_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
return PeerId.from_public_key(public_key_bytes)
|
||||
|
||||
def sign(self, data: bytes) -> bytes:
|
||||
"""Sign data with this keypair's private key."""
|
||||
return self._private_key.sign(data)
|
||||
|
||||
def verify(self, data: bytes, signature: bytes) -> bool:
|
||||
"""Verify a signature against data using this keypair's public key."""
|
||||
try:
|
||||
self._public_key.verify(signature, data)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def public_key_bytes(self) -> bytes:
|
||||
"""Get the raw public key bytes."""
|
||||
return self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
|
||||
@property
|
||||
def private_key_bytes(self) -> bytes:
|
||||
"""Get the raw private key bytes."""
|
||||
return self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
# py-libp2p compatibility properties
|
||||
@property
|
||||
def private_key(self) -> ed25519.Ed25519PrivateKey:
|
||||
"""Access to the underlying private key for py-libp2p compatibility."""
|
||||
return self._private_key
|
||||
|
||||
@property
|
||||
def public_key(self) -> ed25519.Ed25519PublicKey:
|
||||
"""Access to the underlying public key for py-libp2p compatibility."""
|
||||
return self._public_key
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR,
|
||||
) -> Keypair:
|
||||
"""
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
|
||||
def lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
# operate with cross-process lock to avoid race conditions
|
||||
with FileLock(lock_path(path)):
|
||||
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logging.warning(
|
||||
f"Encountered error when trying to get keypair: {e}"
|
||||
)
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
return keypair
|
||||
@@ -1,32 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import loguru
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.constants import EXO_TEST_LOG
|
||||
|
||||
|
||||
def is_user_facing(record: loguru.Record) -> bool:
|
||||
return ("user_facing" in record["extra"]) and record["extra"]["user_facing"]
|
||||
|
||||
|
||||
def logger_setup(log_file: Path, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
logger.remove()
|
||||
if verbosity == 0:
|
||||
_ = logger.add( # type: ignore
|
||||
sys.__stderr__, # type: ignore
|
||||
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
level="INFO",
|
||||
colorize=True,
|
||||
enqueue=True,
|
||||
filter=is_user_facing,
|
||||
)
|
||||
elif verbosity == 1:
|
||||
_ = logger.add( # type: ignore
|
||||
sys.__stderr__, # type: ignore
|
||||
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
@@ -40,11 +21,12 @@ def logger_setup(log_file: Path, verbosity: int = 0):
|
||||
format="[ {time:HH:mm:ss.SSS} | <level>{level: <8}</level> | {name}:{function}:{line} ] <level>{message}</level>",
|
||||
level="DEBUG",
|
||||
colorize=True,
|
||||
enqueue=True,
|
||||
)
|
||||
_ = logger.add(
|
||||
log_file,
|
||||
format="[ {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} ] {message}",
|
||||
level="DEBUG",
|
||||
level="INFO",
|
||||
enqueue=True,
|
||||
)
|
||||
|
||||
@@ -52,10 +34,3 @@ def logger_setup(log_file: Path, verbosity: int = 0):
|
||||
def logger_cleanup():
|
||||
"""Flush all queues before shutting down so any in-flight logs are written to disk"""
|
||||
logger.complete()
|
||||
|
||||
|
||||
def logger_test_install(py_logger: Logger):
|
||||
"""Installs a default python logger into the Loguru environment by capturing all its handlers - intended to be used for pytest compatibility, not within the main codebase"""
|
||||
logger_setup(EXO_TEST_LOG, 3)
|
||||
for handler in py_logger.handlers:
|
||||
logger.add(handler)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
class ModelCard(CamelCaseModel):
|
||||
short_id: str
|
||||
model_id: str
|
||||
name: str
|
||||
@@ -23,9 +23,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
storage_size_kilobytes=409706307,
|
||||
storage_size=Memory.from_kb(409706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -36,9 +36,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
storage_size_kilobytes=754706307,
|
||||
storage_size=Memory.from_kb(754706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -49,9 +49,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-V3.1-8bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
pretty_name="DeepSeek V3.1 (8-bit)",
|
||||
storage_size_kilobytes=754706307,
|
||||
storage_size=Memory.from_kb(754706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -62,9 +62,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-V3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
pretty_name="DeepSeek V3.1 (4-bit)",
|
||||
storage_size_kilobytes=754706307 // 2, # TODO !!!!!
|
||||
storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -76,9 +76,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
storage_size_kilobytes=409706307,
|
||||
storage_size=Memory.from_kb(409706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -89,9 +89,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
storage_size_kilobytes=754998771712 // 1024,
|
||||
storage_size=Memory.from_bytes(754998771712),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
@@ -103,9 +103,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 8B",
|
||||
storage_size_kilobytes=4411528,
|
||||
storage_size=Memory.from_kb(4411528),
|
||||
n_layers=32,
|
||||
),
|
||||
),
|
||||
@@ -116,9 +116,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 70B",
|
||||
storage_size_kilobytes=38758160,
|
||||
storage_size=Memory.from_kb(38758160),
|
||||
n_layers=80,
|
||||
),
|
||||
),
|
||||
@@ -130,9 +130,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
storage_size_kilobytes=678948,
|
||||
storage_size=Memory.from_kb(678948),
|
||||
n_layers=16,
|
||||
),
|
||||
),
|
||||
@@ -143,9 +143,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 3B",
|
||||
storage_size_kilobytes=1765062,
|
||||
storage_size=Memory.from_kb(1765062),
|
||||
n_layers=28,
|
||||
),
|
||||
),
|
||||
@@ -157,9 +157,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.3 70B",
|
||||
storage_size_kilobytes=38758160,
|
||||
storage_size=Memory.from_kb(38758160),
|
||||
n_layers=80,
|
||||
),
|
||||
),
|
||||
@@ -171,9 +171,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k",
|
||||
storage_size_kilobytes=2099262,
|
||||
storage_size=Memory.from_kb(2099262),
|
||||
n_layers=32,
|
||||
),
|
||||
),
|
||||
@@ -184,9 +184,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k",
|
||||
storage_size_kilobytes=2099262,
|
||||
storage_size=Memory.from_kb(2099262),
|
||||
n_layers=32,
|
||||
),
|
||||
),
|
||||
@@ -198,9 +198,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Qwen3-0.6B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
pretty_name="Qwen3 0.6B",
|
||||
storage_size_kilobytes=327512,
|
||||
storage_size=Memory.from_kb(327512),
|
||||
n_layers=28,
|
||||
),
|
||||
),
|
||||
@@ -211,9 +211,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/Qwen3-30B-A3B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
pretty_name="Qwen3 30B (Active 3B)",
|
||||
storage_size_kilobytes=16772092,
|
||||
storage_size=Memory.from_kb(16772092),
|
||||
n_layers=48,
|
||||
),
|
||||
),
|
||||
@@ -225,9 +225,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/granite-3.3-2b-instruct-fp16",
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 2B",
|
||||
storage_size_kilobytes=4948320,
|
||||
storage_size=Memory.from_kb(4948320),
|
||||
n_layers=40,
|
||||
),
|
||||
),
|
||||
@@ -238,9 +238,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
|
||||
model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 8B",
|
||||
storage_size_kilobytes=15958720,
|
||||
storage_size=Memory.from_kb(15958720),
|
||||
n_layers=40,
|
||||
),
|
||||
),
|
||||
@@ -252,9 +252,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id="mlx-community/SmolLM-135M-4bit",
|
||||
model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
pretty_name="Smol LM 135M",
|
||||
storage_size_kilobytes=73940,
|
||||
storage_size=Memory.from_kb(73940),
|
||||
n_layers=30,
|
||||
),
|
||||
),
|
||||
|
||||
@@ -6,7 +6,8 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.worker.download.download_utils import (
|
||||
ModelSafetensorsIndex,
|
||||
download_file_with_retry,
|
||||
@@ -65,7 +66,7 @@ async def get_config_data(model_id: str) -> ConfigData:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def get_safetensors_size(model_id: str) -> int:
|
||||
async def get_safetensors_size(model_id: str) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
@@ -83,12 +84,12 @@ async def get_safetensors_size(model_id: str) -> int:
|
||||
|
||||
metadata = index_data.metadata
|
||||
if metadata is not None:
|
||||
return metadata.total_size
|
||||
return Memory.from_bytes(metadata.total_size)
|
||||
|
||||
info = model_info(model_id)
|
||||
if info.safetensors is None:
|
||||
raise ValueError(f"No safetensors info found for {model_id}")
|
||||
return info.safetensors.total
|
||||
return Memory.from_bytes(info.safetensors.total)
|
||||
|
||||
|
||||
_model_meta_cache: Dict[str, ModelMetadata] = {}
|
||||
@@ -109,8 +110,8 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
|
||||
return ModelMetadata(
|
||||
model_id=model_id,
|
||||
model_id=ModelId(model_id),
|
||||
pretty_name=model_id,
|
||||
storage_size_kilobytes=mem_size_bytes // 1024,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
)
|
||||
|
||||
313
src/exo/shared/tests/test_election.py
Normal file
313
src/exo/shared/tests/test_election.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.channels import channel
|
||||
|
||||
# ======= #
|
||||
# Helpers #
|
||||
# ======= #
|
||||
|
||||
|
||||
def em(clock: int, seniority: int, node_id: str) -> ElectionMessage:
|
||||
return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fast_timeout(monkeypatch: pytest.MonkeyPatch):
|
||||
# Keep campaigns fast; user explicitly allows tests to shorten the timeout.
|
||||
import exo.shared.election as election_mod
|
||||
|
||||
monkeypatch.setattr(election_mod, "ELECTION_TIMEOUT", 0.05, raising=True)
|
||||
yield
|
||||
|
||||
|
||||
# ======================================= #
|
||||
# TESTS #
|
||||
# ======================================= #
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
|
||||
fast_timeout: None,
|
||||
) -> None:
|
||||
"""
|
||||
Start a round by injecting an ElectionMessage with higher clock.
|
||||
With only our node effectively 'winning', we should broadcast once and update seniority.
|
||||
"""
|
||||
# Outbound election messages from the Election (we'll observe these)
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
# Inbound election messages to the Election (we'll inject these)
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
# Election results produced by the Election (we'll observe these)
|
||||
er_tx, er_rx = channel[ElectionResult]()
|
||||
# Connection messages (unused in this test but required by ctor)
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("B"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
# Trigger new round at clock=1 (peer announces it)
|
||||
await em_in_tx.send(em(clock=1, seniority=0, node_id="A"))
|
||||
|
||||
# Expect our broadcast back to the peer side for this round only
|
||||
while True:
|
||||
got = await em_out_rx.receive()
|
||||
if got.clock == 1 and got.node_id == NodeId("B"):
|
||||
break
|
||||
|
||||
# Wait for the round to finish and produce an ElectionResult
|
||||
result = await er_rx.receive()
|
||||
assert result.node_id == NodeId("B")
|
||||
# We spawned as master; electing ourselves again is not "new master".
|
||||
assert result.is_new_master is False
|
||||
|
||||
# Close inbound streams to end the receivers (and run())
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# We should have updated seniority to 2 (A + B).
|
||||
assert election.seniority == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_peer_with_higher_seniority_wins_and_we_switch_master(
|
||||
fast_timeout: None,
|
||||
) -> None:
|
||||
"""
|
||||
If a peer with clearly higher seniority participates in the round, they should win.
|
||||
We should broadcast our status exactly once for this round, then switch master.
|
||||
"""
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
er_tx, er_rx = channel[ElectionResult]()
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("ME"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Start round with peer's message (higher seniority)
|
||||
await em_in_tx.send(em(clock=1, seniority=10, node_id="PEER"))
|
||||
|
||||
# We should still broadcast our status exactly once for this round
|
||||
while True:
|
||||
got = await em_out_rx.receive()
|
||||
if got.clock == 1:
|
||||
assert got.seniority == 0
|
||||
break
|
||||
|
||||
# After the timeout, election result should report the peer as master
|
||||
result = await er_rx.receive()
|
||||
assert result.node_id == NodeId("PEER")
|
||||
assert result.is_new_master is True
|
||||
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# We lost → seniority unchanged
|
||||
assert election.seniority == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ignores_older_messages(fast_timeout: None) -> None:
|
||||
"""
|
||||
Messages with a lower clock than the current round are ignored by the receiver.
|
||||
Expect exactly one broadcast for the higher clock round.
|
||||
"""
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
er_tx, _er_rx = channel[ElectionResult]()
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("ME"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Newer round arrives first -> triggers campaign at clock=2
|
||||
await em_in_tx.send(em(clock=2, seniority=0, node_id="A"))
|
||||
while True:
|
||||
first = await em_out_rx.receive()
|
||||
if first.clock == 2:
|
||||
break
|
||||
|
||||
# Older message (clock=1) must be ignored (no second broadcast)
|
||||
await em_in_tx.send(em(clock=1, seniority=999, node_id="B"))
|
||||
|
||||
got_second = False
|
||||
with move_on_after(0.2):
|
||||
_ = await em_out_rx.receive()
|
||||
got_second = True
|
||||
assert not got_second, "Should not receive a broadcast for an older round"
|
||||
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# Not asserting on the result; focus is on ignore behavior.
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
|
||||
fast_timeout: None,
|
||||
) -> None:
|
||||
"""
|
||||
Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.
|
||||
"""
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
er_tx, _er_rx = channel[ElectionResult]()
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("ME"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Round 1 at clock=1
|
||||
await em_in_tx.send(em(clock=1, seniority=0, node_id="X"))
|
||||
while True:
|
||||
m1 = await em_out_rx.receive()
|
||||
if m1.clock == 1:
|
||||
break
|
||||
|
||||
# Round 2 at clock=2
|
||||
await em_in_tx.send(em(clock=2, seniority=0, node_id="Y"))
|
||||
while True:
|
||||
m2 = await em_out_rx.receive()
|
||||
if m2.clock == 2:
|
||||
break
|
||||
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# Not asserting on who won; just that both rounds were broadcast.
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -> None:
|
||||
"""
|
||||
When we win against two peers in the same round, our seniority becomes
|
||||
max(existing, number_of_candidates). With existing=0: expect 3 (us + A + B).
|
||||
"""
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
er_tx, er_rx = channel[ElectionResult]()
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("ME"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Start round at clock=7 with two peer participants
|
||||
await em_in_tx.send(em(clock=7, seniority=0, node_id="A"))
|
||||
await em_in_tx.send(em(clock=7, seniority=0, node_id="B"))
|
||||
|
||||
# We should see exactly one broadcast from us for this round
|
||||
while True:
|
||||
got = await em_out_rx.receive()
|
||||
if got.clock == 7 and got.node_id == NodeId("ME"):
|
||||
break
|
||||
|
||||
# Wait for the election to finish so seniority updates
|
||||
_ = await er_rx.receive()
|
||||
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# We + A + B = 3 → new seniority expected to be 3
|
||||
assert election.seniority == 3
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_connection_message_triggers_new_round_broadcast(
|
||||
fast_timeout: None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection message increments the clock and starts a new campaign.
|
||||
We should observe a broadcast at the incremented clock.
|
||||
"""
|
||||
em_out_tx, em_out_rx = channel[ElectionMessage]()
|
||||
em_in_tx, em_in_rx = channel[ElectionMessage]()
|
||||
er_tx, _er_rx = channel[ElectionResult]()
|
||||
cm_tx, cm_rx = channel[ConnectionMessage]()
|
||||
|
||||
election = Election(
|
||||
node_id=NodeId("ME"),
|
||||
election_message_receiver=em_in_rx,
|
||||
election_message_sender=em_out_tx,
|
||||
election_result_sender=er_tx,
|
||||
connection_message_receiver=cm_rx,
|
||||
is_candidate=True,
|
||||
)
|
||||
|
||||
async with create_task_group() as tg:
|
||||
with fail_after(2):
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
got = await em_out_rx.receive()
|
||||
if got.clock == 1 and got.node_id == NodeId("ME"):
|
||||
break
|
||||
|
||||
# Close promptly to avoid waiting for campaign completion
|
||||
await em_in_tx.aclose()
|
||||
await cm_tx.aclose()
|
||||
|
||||
# After cancellation (before election finishes), no seniority changes asserted here.
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from exo.shared.ipc.file_mutex.flock_mutex import FlockMutex, LockType
|
||||
from exo.shared.utils.fs import delete_if_exists, make_temp_path
|
||||
from exo.utils.fs import delete_if_exists, make_temp_path
|
||||
|
||||
|
||||
def test_lock_held():
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import multiprocessing
|
||||
@@ -13,8 +11,8 @@ from typing import Optional
|
||||
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.shared.keypair import get_node_id_keypair
|
||||
|
||||
NUM_CONCURRENT_PROCS = 10
|
||||
|
||||
|
||||
@@ -1,612 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from exo.shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import ChunkGenerated
|
||||
from exo.shared.types.events.chunks import ChunkType, TokenChunk
|
||||
|
||||
# Type ignore comment for all protected member access in this test file
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
|
||||
def _load_json_data(raw_data: str) -> dict[str, Any]:
|
||||
"""Helper function to load JSON data with proper typing."""
|
||||
return cast(dict[str, Any], json.loads(raw_data))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path() -> Generator[Path, None, None]:
|
||||
"""Create a temporary database file for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
yield Path(f.name)
|
||||
# Cleanup
|
||||
Path(f.name).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_node_id() -> NodeId:
|
||||
"""Create a sample NodeId for testing."""
|
||||
return NodeId()
|
||||
|
||||
|
||||
class TestAsyncSQLiteEventStorage:
|
||||
"""Test suite for AsyncSQLiteEventStorage focused on storage functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_creates_tables(self, temp_db_path: Path) -> None:
|
||||
"""Test that database initialization creates the events table."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Verify table exists by querying directly
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='events'"
|
||||
)
|
||||
)
|
||||
tables = result.fetchall()
|
||||
assert len(tables) == 1
|
||||
assert tables[0][0] == "events"
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_twice_raises_error(self, temp_db_path: Path) -> None:
|
||||
"""Test that starting storage twice raises an error."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Storage already started"):
|
||||
await storage.start()
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_database_operations(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test direct database operations without event parsing."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Insert test data directly
|
||||
test_data = {
|
||||
"event_type": "test_event",
|
||||
"test_field": "test_value",
|
||||
"number": 42,
|
||||
}
|
||||
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": sample_node_id,
|
||||
"event_type": "test_event",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps(test_data),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Query data back
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0][0] == 1 # rowid
|
||||
assert rows[0][1] == sample_node_id # origin
|
||||
raw_json = cast(str, rows[0][2])
|
||||
retrieved_data = _load_json_data(raw_json)
|
||||
assert retrieved_data == test_data
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rowid_auto_increment(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test that rowid auto-increments correctly."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Insert multiple records
|
||||
test_records = [
|
||||
{"event_type": "test_event_1", "data": "first"},
|
||||
{"event_type": "test_event_2", "data": "second"},
|
||||
{"event_type": "test_event_3", "data": "third"},
|
||||
]
|
||||
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
for record in test_records:
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": sample_node_id,
|
||||
"event_type": record["event_type"],
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps(record),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Query back and verify rowid sequence
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text("SELECT rowid, event_data FROM events ORDER BY rowid")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 3
|
||||
for i, row in enumerate(rows):
|
||||
assert row[0] == i + 1 # rowid starts at 1
|
||||
raw_json = cast(str, row[1])
|
||||
retrieved_data = _load_json_data(raw_json)
|
||||
assert retrieved_data == test_records[i]
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_last_idx(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test that rowid returns correctly from db."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Insert multiple records
|
||||
test_records = [
|
||||
{"event_type": "test_event_1", "data": "first"},
|
||||
{"event_type": "test_event_2", "data": "second"},
|
||||
{"event_type": "test_event_3", "data": "third"},
|
||||
]
|
||||
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
for record in test_records:
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": sample_node_id,
|
||||
"event_type": record["event_type"],
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps(record),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
last_idx = await storage.get_last_idx()
|
||||
assert last_idx == 3
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None:
|
||||
"""Test rowid sequence across multiple origins."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
origin1 = NodeId()
|
||||
origin2 = NodeId()
|
||||
|
||||
# Insert interleaved records from different origins
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
# Origin 1 - record 1
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": origin1,
|
||||
"event_type": "event_1",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps({"from": "origin1", "seq": 1}),
|
||||
},
|
||||
)
|
||||
# Origin 2 - record 2
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": origin2,
|
||||
"event_type": "event_2",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps({"from": "origin2", "seq": 2}),
|
||||
},
|
||||
)
|
||||
# Origin 1 - record 3
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": origin1,
|
||||
"event_type": "event_3",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps({"from": "origin1", "seq": 3}),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Verify sequential rowid regardless of origin
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 3
|
||||
assert rows[0][0] == 1 # First rowid
|
||||
assert rows[1][0] == 2 # Second rowid
|
||||
assert rows[2][0] == 3 # Third rowid
|
||||
|
||||
# Verify data integrity
|
||||
raw_json1 = cast(str, rows[0][2])
|
||||
raw_json2 = cast(str, rows[1][2])
|
||||
raw_json3 = cast(str, rows[2][2])
|
||||
data1 = _load_json_data(raw_json1)
|
||||
data2 = _load_json_data(raw_json2)
|
||||
data3 = _load_json_data(raw_json3)
|
||||
|
||||
assert data1["from"] == "origin1" and data1["seq"] == 1
|
||||
assert data2["from"] == "origin2" and data2["seq"] == 2
|
||||
assert data3["from"] == "origin1" and data3["seq"] == 3
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_events_since_index(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test querying events after a specific rowid."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Insert 10 test records
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
for i in range(10):
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": sample_node_id,
|
||||
"event_type": f"event_{i}",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps({"index": i}),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Query events after index 5
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT rowid, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid"
|
||||
),
|
||||
{"last_idx": 5},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 5 # Should get records 6-10
|
||||
for i, row in enumerate(rows):
|
||||
assert row[0] == i + 6 # rowid 6, 7, 8, 9, 10
|
||||
raw_json = cast(str, row[1])
|
||||
data = _load_json_data(raw_json)
|
||||
assert data["index"] == i + 5 # index 5, 6, 7, 8, 9
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_query(self, temp_db_path: Path) -> None:
|
||||
"""Test querying when no events exist."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT rowid, origin, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid"
|
||||
),
|
||||
{"last_idx": 0},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 0
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_operations_after_close_raise_error(self, temp_db_path: Path) -> None:
|
||||
"""Test that operations after close work properly."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
await storage.close()
|
||||
|
||||
# These should not raise errors since we're not using the public API
|
||||
assert storage._closed is True
|
||||
assert storage._engine is not None # Engine should still exist but be disposed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_close_calls_safe(self, temp_db_path: Path) -> None:
|
||||
"""Test that multiple close calls are safe."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
await storage.close()
|
||||
await storage.close() # Should not raise an error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_data_types(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test that various JSON data types are handled correctly."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Test various JSON data types
|
||||
test_data = {
|
||||
"string": "test string",
|
||||
"number": 42,
|
||||
"float": 3.14,
|
||||
"boolean": True,
|
||||
"null": None,
|
||||
"array": [1, 2, 3, "four"],
|
||||
"object": {"nested": "value", "deep": {"deeper": "nested"}},
|
||||
"unicode": "测试 🚀",
|
||||
}
|
||||
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": sample_node_id,
|
||||
"event_type": "complex_event",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps(test_data),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Query back and verify data integrity
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text("SELECT event_data FROM events WHERE event_type = :event_type"),
|
||||
{"event_type": "complex_event"},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 1
|
||||
raw_json = cast(str, rows[0][0])
|
||||
retrieved_data = _load_json_data(raw_json)
|
||||
assert retrieved_data == test_data
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_inserts(self, temp_db_path: Path) -> None:
|
||||
"""Test concurrent inserts maintain rowid ordering."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
async def insert_batch(origin_id: str, batch_id: int, count: int) -> None:
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
for i in range(count):
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
|
||||
),
|
||||
{
|
||||
"origin": origin_id,
|
||||
"event_type": f"batch_{batch_id}_event_{i}",
|
||||
"event_id": str(uuid4()),
|
||||
"event_data": json.dumps({"batch": batch_id, "item": i}),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Run multiple concurrent insert batches
|
||||
origin1 = str(uuid4())
|
||||
origin2 = str(uuid4())
|
||||
origin3 = str(uuid4())
|
||||
|
||||
await asyncio.gather(
|
||||
insert_batch(origin1, 1, 5),
|
||||
insert_batch(origin2, 2, 5),
|
||||
insert_batch(origin3, 3, 5),
|
||||
)
|
||||
|
||||
# Verify all records were inserted and rowid is sequential
|
||||
assert storage._engine is not None
|
||||
async with AsyncSession(storage._engine) as session:
|
||||
result = await session.execute(
|
||||
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
assert len(rows) == 15 # 3 batches * 5 records each
|
||||
|
||||
# Verify rowid sequence is maintained
|
||||
for i, row in enumerate(rows):
|
||||
assert row[0] == i + 1 # rowid should be sequential
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_generated_event_serialization(
|
||||
self, temp_db_path: Path, sample_node_id: NodeId
|
||||
) -> None:
|
||||
"""Test that ChunkGenerated event with nested types can be serialized and deserialized correctly."""
|
||||
default_config = EventLogConfig()
|
||||
storage = AsyncSQLiteEventStorage(
|
||||
db_path=temp_db_path,
|
||||
batch_size=default_config.batch_size,
|
||||
batch_timeout_ms=default_config.batch_timeout_ms,
|
||||
debounce_ms=default_config.debounce_ms,
|
||||
max_age_ms=default_config.max_age_ms,
|
||||
)
|
||||
await storage.start()
|
||||
|
||||
# Create a ChunkGenerated event with nested TokenChunk
|
||||
command_id = CommandId()
|
||||
token_chunk = TokenChunk(
|
||||
text="Hello, world!",
|
||||
token_id=42,
|
||||
finish_reason="stop",
|
||||
chunk_type=ChunkType.token,
|
||||
command_id=command_id,
|
||||
idx=0,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
chunk_generated_event = ChunkGenerated(command_id=command_id, chunk=token_chunk)
|
||||
|
||||
# Store the event using the storage API
|
||||
await storage.append_events([chunk_generated_event], sample_node_id)
|
||||
|
||||
# Wait for batch to be written
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Retrieve the event
|
||||
events = await storage.get_events_since(0)
|
||||
|
||||
# Verify we got the event back
|
||||
assert len(events) == 1
|
||||
retrieved_event_wrapper = events[0]
|
||||
assert retrieved_event_wrapper.origin == sample_node_id
|
||||
|
||||
# Verify the event was deserialized correctly
|
||||
retrieved_event = retrieved_event_wrapper.event
|
||||
assert isinstance(retrieved_event, ChunkGenerated)
|
||||
assert retrieved_event.command_id == command_id
|
||||
|
||||
# Verify the nested chunk was deserialized correctly
|
||||
retrieved_chunk = retrieved_event.chunk
|
||||
assert isinstance(retrieved_chunk, TokenChunk)
|
||||
assert retrieved_chunk.chunk_type == ChunkType.token
|
||||
assert retrieved_chunk.command_id == command_id
|
||||
assert retrieved_chunk.idx == 0
|
||||
assert retrieved_chunk.model == "test-model"
|
||||
|
||||
# Verify the chunk data
|
||||
assert retrieved_chunk.text == "Hello, world!"
|
||||
assert retrieved_chunk.token_id == 42
|
||||
assert retrieved_chunk.finish_reason == "stop"
|
||||
|
||||
await storage.close()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user