big refactor

Fix. Everything.

Co-authored-by: Andrei Cravtov <the.andrei.cravtov@gmail.com>
Co-authored-by: Matt Beton <matthew.beton@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Seth Howes <sethshowes@gmail.com>
This commit is contained in:
Evan Quiney
2025-09-30 11:03:04 +01:00
committed by GitHub
parent 7040c9508f
commit 38ff949bf4
171 changed files with 8295 additions and 4614 deletions

View File

@@ -1,71 +0,0 @@
# Configures the Golang support and builds the forwarder
# TODO: split this up in the future as this is unrelated tasks??
# Top-level parameters that are bound to the provider flake
# These are passed from `flake.nix` using importApply
{
localSelf,
flake-parts-lib,
nixpkgs-lib,
...
}:
# These values would bind to the consumer flake when this flake module is imported:
{
config,
self,
inputs,
getSystem,
moduleWithSystem,
withSystem,
...
}:
# The actual flake-parts module configuration
{
perSystem =
{
config,
self',
inputs',
pkgs,
system,
...
}:
let
# flakeRoot = nixpkgs-lib.getExe config.flake-root.package;
#
# # Build the networking/forwarder Go utility.
# forwarder = pkgs.buildGoModule {
# pname = "exo-forwarder";
# version = "0.1.0";
# src = "${flakeRoot}/networking/forwarder";
#
# vendorHash = "sha256-BXIGg2QYqHDz2TNe8hLAGC6jVlffp9766H+WdkkuVgA=";
#
# # Only the main package at the repository root needs building.
# subPackages = [ "." ];
# };
in
{
packages = {
# inherit forwarder;
};
apps = {
# forwarder = {
# type = "app";
# program = "${forwarder}/bin/forwarder";
# };
};
make-shells.default = {
# Go 1.24 compiler align with go.mod
packages = [ pkgs.go_1_24 ];
shellHook = ''
GOPATH="''$(${nixpkgs-lib.getExe config.flake-root.package})"/.go_cache
export GOPATH
'';
};
};
}

View File

@@ -1,54 +0,0 @@
# Provides pretty banner & command index for this flake
# Top-level parameters that are bound to the provider flake
# These are passed from `flake.nix` using importApply
{
localSelf,
flake-parts-lib,
nixpkgs-lib,
just-flake,
...
}:
# These values would bind to the consumer flake when this flake module is imported:
{
config,
self,
inputs,
getSystem,
moduleWithSystem,
withSystem,
...
}:
# The actual flake-parts module configuration
{
imports = [ just-flake.flakeModule ];
perSystem =
{
config,
self',
inputs',
pkgs,
system,
...
}:
{
just-flake.features = {
# treefmt.enable = true;
# rust.enable = true;
# convco.enable = true;
# hello = {
# enable = true;
# justfile = ''
# hello:
# echo Hello World
# '';
# };
};
make-shells.default = {
inputsFrom = [ config.just-flake.outputs.devShell ];
};
};
}

View File

@@ -1,30 +0,0 @@
# Provides macmon binary for the worker.
# These values would bind to the consumer flake when this flake module is imported:
{
config,
self,
inputs,
getSystem,
moduleWithSystem,
withSystem,
...
}:
# The actual flake-parts module configuration
{
perSystem =
{
config,
self',
inputs',
pkgs,
system,
...
}:
{
make-shells.default = {
packages = if (system == "aarch64-darwin") then ([ pkgs.macmon ]) else ([]);
};
};
}

4
.gitignore vendored
View File

@@ -12,8 +12,6 @@ hosts*.json
# TODO figure out how to properly solve the issue with these target directories showing up
networking/target/
networking/topology/target/
rust/target/
rust/Cargo.lock
build/
dist/
@@ -26,4 +24,4 @@ dist/
just-flake.just
# for the gitingest enthusiasts
digest.txt
digest.txt

8
.idea/exo-v2.iml generated
View File

@@ -10,11 +10,19 @@
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/scripts/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/exo_pyo3_bindings/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/exo_pyo3_bindings/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/rust/util/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/networking/examples" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/networking/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/rust/networking/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/rust/system_custodian/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/.direnv" />
<excludeFolder url="file://$MODULE_DIR$/build" />
<excludeFolder url="file://$MODULE_DIR$/dist" />
<excludeFolder url="file://$MODULE_DIR$/.go_cache" />
<excludeFolder url="file://$MODULE_DIR$/rust/target" />
</content>
<orderEntry type="jdk" jdkName="Python 3.13 (exo)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />

2
.idea/vcs.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="" vcs="Git" />
</component>
</project>

133
copy_model.sh Executable file
View File

@@ -0,0 +1,133 @@
#!/usr/bin/env bash
set -euo pipefail
# copy_model.sh: clone ~/.exo/models from SOURCE to one or more TARGETS using scp -3.
# Username defaults:
# - If host is "aN" and no user given, username defaults to "aN".
# - Otherwise defaults to $(whoami), unless you pass user@host.
#
# Examples:
# ./copy_model.sh a1 a2 a3
# ./copy_model.sh a1 frank@a2 192.168.1.3
if [ $# -lt 2 ]; then
echo "Usage: $0 SOURCE TARGET [TARGET...]" >&2
exit 2
fi
SOURCE="$1"
shift
TARGETS=("$@")
DEFAULT_USER="$(whoami)"
MODELS_REL=".exo/models" # relative under $HOME
timestamp() { date "+%Y-%m-%d %H:%M:%S"; }
split_user_host() {
local in="$1"
if [[ "$in" == *"@"* ]]; then
printf "%s|%s" "${in%%@*}" "${in#*@}"
else
printf "|%s" "$in"
fi
}
resolve_ip() {
local hostish="$1"
if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
echo "192.168.1.${BASH_REMATCH[1]}"
else
echo "$hostish"
fi
}
default_user_for() {
local hostish="$1"
if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
echo "$hostish"
else
echo "$DEFAULT_USER"
fi
}
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR -o ConnectTimeout=10)
SSHPASS_BIN="$(command -v sshpass || true)"
SCP_BIN="${SCP_BIN:-scp}"
read -s -p "Password for all hosts: " PASS
echo
if [ -n "$SSHPASS_BIN" ]; then
echo "$(timestamp) sshpass found: will provide the password non-interactively."
else
echo "$(timestamp) WARNING: sshpass not found — youll be prompted by scp/ssh per hop unless keys are set up."
fi
# Build source endpoint (default username logic)
IFS='|' read -r SRC_USER_RAW SRC_HOSTISH <<<"$(split_user_host "$SOURCE")"
SRC_USER="${SRC_USER_RAW:-$(default_user_for "$SRC_HOSTISH")}"
SRC_IP="$(resolve_ip "$SRC_HOSTISH")"
SRC_HOST="${SRC_USER}@${SRC_IP}"
echo "$(timestamp) Source: ${SRC_HOST}:~/${MODELS_REL}"
echo "$(timestamp) Targets: ${#TARGETS[@]}"
# Helper to run a simple remote command via ssh (for mkdir -p checks)
ssh_run() {
local host="$1"
shift
if [ -n "$SSHPASS_BIN" ]; then
sshpass -p "$PASS" ssh "${SSH_OPTS[@]}" "$host" "$@"
else
ssh "${SSH_OPTS[@]}" "$host" "$@"
fi
}
# Ensure source dir exists (create if missing, per your request)
ssh_run "$SRC_HOST" "mkdir -p ~/${MODELS_REL}"
failures=0
count=0
for T in "${TARGETS[@]}"; do
count=$((count + 1))
IFS='|' read -r T_USER_RAW T_HOSTISH <<<"$(split_user_host "$T")"
T_USER="${T_USER_RAW:-$(default_user_for "$T_HOSTISH")}"
T_IP="$(resolve_ip "$T_HOSTISH")"
T_HOST="${T_USER}@${T_IP}"
echo "============================================================"
echo "$(timestamp) [${count}/${#TARGETS[@]}] ${SRC_HOST} ==> ${T_HOST}"
echo "$(timestamp) Ensuring destination directory exists…"
ssh_run "$T_HOST" "mkdir -p ~/${MODELS_REL%/*}" # ~/.exo
# Copy the whole "models" directory into ~/.exo on the target.
# scp -3 = copy between two remotes via local; -r recursive; -p preserve times/modes
if [ -n "$SSHPASS_BIN" ]; then
echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
if sshpass -p "$PASS" "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
"${SRC_HOST}:~/${MODELS_REL}" \
"${T_HOST}:~/.exo/"; then
echo "$(timestamp) [${count}] Done: ${T_HOST}"
else
echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
failures=$((failures + 1))
fi
else
echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
if "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
"${SRC_HOST}:~/${MODELS_REL}" \
"${T_HOST}:~/.exo/"; then
echo "$(timestamp) [${count}] Done: ${T_HOST}"
else
echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
failures=$((failures + 1))
fi
fi
done
echo "============================================================"
if [ "$failures" -eq 0 ]; then
echo "$(timestamp) All transfers completed successfully."
else
echo "$(timestamp) Completed with ${failures} failure(s)."
fi

View File

@@ -943,7 +943,7 @@
}
const result = await response.json();
showLaunchStatus(`Instance launched successfully: ${result.instance_id}`, 'success');
showLaunchStatus('Instance launched successfully');
// Reset form
modelSelect.value = '';

39
flake.lock generated
View File

@@ -1,5 +1,26 @@
{
"nodes": {
"fenix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1755585599,
"narHash": "sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY=",
"owner": "nix-community",
"repo": "fenix",
"rev": "6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "fenix",
"type": "github"
}
},
"flake-compat": {
"flake": false,
"locked": {
@@ -102,12 +123,30 @@
},
"root": {
"inputs": {
"fenix": "fenix",
"flake-parts": "flake-parts",
"flake-root": "flake-root",
"just-flake": "just-flake",
"make-shell": "make-shell",
"nixpkgs": "nixpkgs"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1755504847,
"narHash": "sha256-VX0B9hwhJypCGqncVVLC+SmeMVd/GAYbJZ0MiiUn2Pk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "a905e3b21b144d77e1b304e49f3264f6f8d4db75",
"type": "github"
},
"original": {
"owner": "rust-lang",
"ref": "nightly",
"repo": "rust-analyzer",
"type": "github"
}
}
},
"root": "root",

View File

@@ -20,47 +20,39 @@
# Provides flake integration with [Just](https://just.systems/man/en/)
just-flake.url = "github:juspay/just-flake";
# Provides Rust dev-env integration:
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
outputs =
inputs@{
flake-parts,
...
}:
flake-parts.lib.mkFlake { inherit inputs; } (
{
flake-parts-lib,
self,
...
}:
let
nixpkgs-lib = inputs.nixpkgs.lib;
# A wraper around importApply that supplies default parameters
importApply' =
path: extraParams:
(flake-parts-lib.importApply path (
nixpkgs-lib.recursiveUpdate {
localSelf = self;
inherit flake-parts-lib;
inherit nixpkgs-lib;
} extraParams
));
# instantiate all the flake modules, passing custom arguments to them as needed
flakeModules = {
flakeRoot = importApply' ./.flake-modules/flake-root.nix { inherit (inputs) flake-root; };
justFlake = importApply' ./.flake-modules/just-flake.nix { inherit (inputs) just-flake; };
goForwarder = importApply' ./.flake-modules/go-forwarder.nix { };
};
in
{ flake-parts-lib, self, ... }:
{
imports = [
inputs.make-shell.flakeModules.default
flakeModules.flakeRoot
flakeModules.justFlake
flakeModules.goForwarder
./.flake-modules/macmon.nix
./nix/modules/pkgs-init.nix # nixpkgs overlays manager
./nix/modules/flake-root.nix
./nix/modules/just-flake.nix
./nix/modules/macmon.nix
./nix/modules/python.nix
./nix/modules/rust.nix
./nix/modules/go-forwarder.nix
];
systems = [
"x86_64-linux"
@@ -75,55 +67,31 @@
system,
...
}:
let
buildInputs = with pkgs; [
];
nativeBuildInputs = with pkgs; [
];
in
{
# Per-system attributes can be defined here. The self' and inputs'
# module parameters provide easy access to attributes of the same
# system.
# NOTE: pkgs is equivalent to inputs'.nixpkgs.legacyPackages.hello;
apps = {
python-lsp = {
type = "app";
program = "${pkgs.basedpyright}/bin/basedpyright-langserver";
};
default = self'.apps.forwarder;
};
apps = { };
make-shells.default = {
packages = [
pkgs.python313
pkgs.uv
pkgs.protobuf
pkgs.basedpyright
pkgs.ruff
];
nativeBuildInputs =
with pkgs;
[
nixpkgs-fmt
cmake
]
++ buildInputs
++ nativeBuildInputs;
# Arguments which are intended to be environment variables in the shell environment
# should be changed to attributes of the `env` option
env = {
# fixes libstdc++.so issues and libgl.so issues
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
};
nativeBuildInputs = with pkgs; [
nixpkgs-fmt
];
shellHook = ''
export GO_BUILD_DIR=$(git rev-parse --show-toplevel)/build;
export DASHBOARD_DIR=$(git rev-parse --show-toplevel)/dashboard;
'';
# Arguments which are intended to be environment variables in the shell environment
# should be changed to attributes of the `env` option
env = { };
# Arbitrary mkDerivation arguments should be changed to be attributes of the `additionalArguments` option
additionalArguments = { };
};

65
kill_remote.sh Executable file
View File

@@ -0,0 +1,65 @@
#!/usr/bin/env bash
set -euo pipefail
###############################################################################
# Args & prerequisites
###############################################################################
if [[ $# -gt 1 ]]; then
echo "Usage: $0 [hosts_file]" >&2
exit 1
fi
HOSTS_FILE=${1:-hosts.txt}
###############################################################################
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
###############################################################################
if [[ ! -f "$HOSTS_FILE" ]]; then
echo "Error: $HOSTS_FILE not found"
exit 1
fi
if builtin command -v mapfile >/dev/null 2>&1; then
mapfile -t HOSTS <"$HOSTS_FILE"
else
HOSTS=()
while IFS= read -r h; do
[[ -n "$h" ]] && HOSTS+=("$h")
done <"$HOSTS_FILE"
fi
[[ ${#HOSTS[@]} -gt 0 ]] || {
echo "No hosts found in $HOSTS_FILE"
exit 1
}
###############################################################################
# Helper run a remote command and capture rc/stderr/stdout
###############################################################################
ssh_opts=(-o StrictHostKeyChecking=no
-o LogLevel=ERROR)
run_remote() { # $1 host $2 command
local host=$1 cmd=$2 rc
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
rc=0
else
rc=$?
fi
return $rc
}
###############################################################################
# Kill exo everywhere (parallel)
###############################################################################
echo "=== Killing exo on ${#HOSTS[@]} host(s) ==="
fail=0
for h in "${HOSTS[@]}"; do
(
run_remote "$h" 'pkill -f exo || true'
) || fail=1 &
done
wait
((fail == 0)) || {
echo "❌ Some hosts could not be reached—check SSH access."
exit 1
}
echo "✓ exo processes killed on all reachable hosts."

View File

@@ -2,39 +2,14 @@
# 1. ${lib.getExe config.flake-root.package}
# 2. $FLAKE_ROOT environment-varible
# Top-level parameters that are bound to the provider flake
# These are passed from `flake.nix` using importApply
{
localSelf,
flake-parts-lib,
nixpkgs-lib,
flake-root,
...
}:
# These values would bind to the consumer flake when this flake module is imported:
{
config,
self,
inputs,
getSystem,
moduleWithSystem,
withSystem,
...
}:
{ inputs, ... }:
# The actual flake-parts module configuration
{
imports = [ flake-root.flakeModule ];
imports = [ inputs.flake-root.flakeModule ];
perSystem =
{
config,
self',
inputs',
pkgs,
system,
...
}:
{ config, ... }:
{
flake-root.projectRootFile = "flake.nix"; # Not necessary, as flake.nix is the default

View File

@@ -0,0 +1,19 @@
{
perSystem =
{
config,
pkgs,
lib,
...
}:
{
make-shells.default = {
# Go 1.24 compiler align with go.mod
packages = [ pkgs.go_1_24 ];
shellHook = ''
GOPATH="''$(${lib.getExe config.flake-root.package})"/.go_cache
export GOPATH
'';
};
};
}

View File

@@ -0,0 +1,26 @@
# Provides pretty banner & command index for this flake
{ inputs, ... }:
{
imports = [ inputs.just-flake.flakeModule ];
perSystem =
{ config, ... }:
{
just-flake.features = {
# treefmt.enable = true;
# rust.enable = true;
# convco.enable = true;
# hello = {
# enable = true;
# justfile = ''
# hello:
# echo Hello World
# '';
# };
};
make-shells.default = {
inputsFrom = [ config.just-flake.outputs.devShell ];
};
};
}

12
nix/modules/macmon.nix Normal file
View File

@@ -0,0 +1,12 @@
{
perSystem =
{ lib, pkgs, ... }:
lib.mkMerge [
(lib.mkIf pkgs.stdenv.isDarwin {
make-shells.default = {
packages = [ pkgs.macmon ];
};
})
];
}

62
nix/modules/pkgs-init.nix Normal file
View File

@@ -0,0 +1,62 @@
# Single module responsible for collecting all overlays and instantiating in one go
{
flake-parts-lib,
inputs,
self,
specialArgs,
...
}:
let
inherit (flake-parts-lib) mkPerSystemOption;
in
{
options.perSystem = mkPerSystemOption (
{
system,
config,
lib,
options,
pkgs,
self',
...
}@args:
let
inherit (lib.types)
attrsOf
listOf
submoduleWith
raw
;
in
{
options.pkgs-init.overlays = lib.mkOption {
description = ''
List of nixpkgs overlays (functions of the form: final: prev: { ... }).
Any module can append. Order matters.
'';
default = [ ];
example = [
(final: prev: {
my-hello = prev.hello;
})
];
type = lib.types.listOf lib.types.unspecified;
};
options.pkgs-init.importArgs = lib.mkOption {
description = "Extra arguments merged into the nixpkgs import call.";
default = { };
type = lib.types.attrs;
};
config = {
_module.args.pkgs = import inputs.nixpkgs (
{
inherit system;
overlays = config.pkgs-init.overlays;
}
// config.pkgs-init.importArgs
);
};
}
);
}

20
nix/modules/python.nix Normal file
View File

@@ -0,0 +1,20 @@
# Configures Python shell
{
perSystem =
{ pkgs, ... }:
{
make-shells.default = {
packages = [
pkgs.python313
pkgs.uv
pkgs.ruff
pkgs.basedpyright
];
shellHook = ''
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${pkgs.python313}/lib
'';
};
};
}

25
nix/modules/rust.nix Normal file
View File

@@ -0,0 +1,25 @@
# Configures Rust shell
{ inputs, ... }:
{
perSystem =
{ pkgs, ... }:
{
pkgs-init.overlays = [
inputs.fenix.overlays.default
];
make-shells.default = {
packages = [
(pkgs.fenix.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
pkgs.rustup # literally only added to make RustRover happy (otherwise useless)
];
};
};
}

View File

@@ -33,6 +33,9 @@ dependencies = [
"cobs>=1.2.2",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio>=4.10.0",
"bidict>=0.23.1",
]
[project.scripts]
@@ -61,8 +64,12 @@ darwin = [
[tool.uv.workspace]
members = [
"scripts",
"rust/exo_pyo3_bindings",
]
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
[build-system]
requires = ["uv_build>=0.8.9,<0.9.0"]
build-backend = "uv_build"
@@ -87,7 +94,7 @@ reportUnnecessaryTypeIgnoreComment = "error"
pythonVersion = "3.13"
pythonPlatform = "Darwin"
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts"]
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
stubPath = "typings"
[[tool.basedpyright.executionEnvironments]]

View File

@@ -4,47 +4,49 @@ set -euo pipefail
###############################################################################
# Args & prerequisites
###############################################################################
if [[ $# -lt 2 ]]; then
echo "Usage: $0 <PASSWORD> <git_command> [git_args...]" >&2
if [[ $# -lt 1 ]]; then
echo "Usage: $0 <git_command> [git_args...]" >&2
echo "Examples:" >&2
echo " $0 mypassword pull" >&2
echo " $0 mypassword checkout main" >&2
echo " $0 mypassword status" >&2
echo " $0 mypassword fetch --all" >&2
echo " $0 pull" >&2
echo " $0 checkout main" >&2
echo " $0 status" >&2
echo " $0 fetch --all" >&2
exit 1
fi
PASSWORD=$1
shift # Remove password from args
GIT_CMD="$*" # Remaining args form the git command
HOSTS_FILE=${HOSTS_FILE:-hosts.json}
for prog in jq sshpass; do
command -v "$prog" >/dev/null ||
{ echo "Error: $prog not installed."; exit 1; }
done
GIT_CMD="$*" # All args form the git command
HOSTS_FILE=${HOSTS_FILE:-hosts.txt}
###############################################################################
# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
###############################################################################
if [[ ! -f "$HOSTS_FILE" ]]; then
echo "Error: $HOSTS_FILE not found"
exit 1
fi
if builtin command -v mapfile >/dev/null 2>&1; then
mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
mapfile -t HOSTS <"$HOSTS_FILE"
else
HOSTS=()
while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
while IFS= read -r h; do
[[ -n "$h" ]] && HOSTS+=("$h")
done <"$HOSTS_FILE"
fi
[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
[[ ${#HOSTS[@]} -gt 0 ]] || {
echo "No hosts found in $HOSTS_FILE"
exit 1
}
###############################################################################
# Helper run a remote command and capture rc/stderr/stdout
###############################################################################
ssh_opts=(-o StrictHostKeyChecking=no
-o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
-o LogLevel=ERROR)
-o LogLevel=ERROR)
run_remote () { # $1 host $2 command
run_remote() { # $1 host $2 command
local host=$1 cmd=$2 rc
if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
rc=0
else
rc=$?
@@ -72,9 +74,9 @@ done
wait
echo ""
if (( fail == 0 )); then
if ((fail == 0)); then
echo "🎉 Git command executed successfully on all hosts!"
else
echo "⚠️ Some hosts failed—see above."
exit 1
fi
fi

View File

@@ -4,38 +4,42 @@ set -euo pipefail
###############################################################################
# Args & prerequisites
###############################################################################
if [[ $# -lt 1 || $# -gt 2 ]]; then
echo "Usage: $0 <PASSWORD> [hosts_file]" >&2 ; exit 1
if [[ $# -gt 1 ]]; then
echo "Usage: $0 [hosts_file]" >&2
exit 1
fi
PASSWORD=$1
HOSTS_FILE=${2:-hosts.json}
for prog in jq sshpass; do
command -v "$prog" >/dev/null ||
{ echo "Error: $prog not installed."; exit 1; }
done
HOSTS_FILE=${1:-hosts.txt}
###############################################################################
# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
###############################################################################
if [[ ! -f "$HOSTS_FILE" ]]; then
echo "Error: $HOSTS_FILE not found"
exit 1
fi
if builtin command -v mapfile >/dev/null 2>&1; then
mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
mapfile -t HOSTS <"$HOSTS_FILE"
else
HOSTS=()
while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
while IFS= read -r h; do
[[ -n "$h" ]] && HOSTS+=("$h")
done <"$HOSTS_FILE"
fi
[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
[[ ${#HOSTS[@]} -gt 0 ]] || {
echo "No hosts found in $HOSTS_FILE"
exit 1
}
###############################################################################
# Helper run a remote command and capture rc/stderr/stdout
###############################################################################
ssh_opts=(-o StrictHostKeyChecking=no
-o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
-o LogLevel=ERROR)
-o LogLevel=ERROR)
run_remote () { # $1 host $2 command
run_remote() { # $1 host $2 command
local host=$1 cmd=$2 rc
if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
rc=0
else
rc=$?
@@ -54,26 +58,42 @@ for h in "${HOSTS[@]}"; do
) || fail=1 &
done
wait
(( fail == 0 )) || { echo "❌ Some hosts could not be reached—check password or SSH access."; exit 1; }
((fail == 0)) || {
echo "❌ Some hosts could not be reached—check SSH access."
exit 1
}
echo "✓ exo processes killed on all reachable hosts."
#
###############################################################################
# Phase 2 start new exo processes (parallel, with sudo -S)
# Phase 2 cleanup database files (parallel)
###############################################################################
echo "=== Stage 2: starting new exo processes ==="
echo "=== Stage 2: cleaning up database files ==="
fail=0
for i in "${!HOSTS[@]}"; do
h=${HOSTS[$i]}
# one liner that pre-caches sudo and then runs the script
if [[ $i -eq 0 ]]; then
remote_cmd="cd ~/exo && ./run.sh -c"
else
remote_cmd="cd ~/exo && ./run.sh -rc"
fi
( run_remote "$h" "$remote_cmd" ) || fail=1 &
for h in "${HOSTS[@]}"; do
(
run_remote "$h" 'rm -f ~/.exo/*db* || true'
) || fail=1 &
done
wait
(( fail == 0 )) && echo "🎉 Deployment finished!" || \
{ echo "⚠️ Some starts failed—see above."; exit 1; }
((fail == 0)) || {
echo " Some hosts failed database cleanup."
exit 1
}
echo "✓ Database files cleaned on all hosts."
###############################################################################
# Phase 3 start new exo processes in Terminal windows (parallel)
###############################################################################
echo "=== Stage 3: starting new exo processes ==="
fail=0
for h in "${HOSTS[@]}"; do
# Use osascript to open Terminal windows on remote Mac
remote_cmd="osascript -e \"tell app \\\"Terminal\\\" to do script \\\"cd ~/exo; nix develop --command uv run exo\\\"\""
(run_remote "$h" "$remote_cmd") || fail=1 &
done
wait
((fail == 0)) && echo "🎉 Deployment finished!" || {
echo "⚠️ Some starts failed—see above."
exit 1
}

15
rust/.gitignore vendored Normal file
View File

@@ -0,0 +1,15 @@
# Generated by Cargo
# will have compiled files and executables
debug
target
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
# Generated by cargo mutants
# Contains mutation testing data
**/mutants.out*/

165
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,165 @@
[workspace]
resolver = "3"
members = [
"networking",
"exo_pyo3_bindings",
"system_custodian",
"util",
]
[workspace.package]
version = "0.0.1"
edition = "2024"
[profile.dev]
opt-level = 1
debug = true
[profile.release]
opt-level = 3
# Common shared dependendencies configured once at the workspace
# level, to be re-used more easily across workspace member crates.
#
# Common configurations include versions, paths, features, etc.
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "networking" }
system_custodian = { path = "system_custodian" }
util = { path = "util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"
# networking
libp2p = "0.56"
libp2p-tcp = "0.44"
[workspace.lints.rust]
static_mut_refs = "warn" # Or use "warn" instead of deny
incomplete_features = "allow"
# Clippy's lint category level configurations;
# every member crate needs to inherit these by adding
#
# ```toml
# [lints]
# workspace = true
# ```
#
# to their `Cargo.toml` files
[workspace.lints.clippy]
# Clippy lint categories meant to be enabled all at once
correctness = { level = "deny", priority = -1 }
suspicious = { level = "warn", priority = -1 }
style = { level = "warn", priority = -1 }
complexity = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
nursery = { level = "warn", priority = -1 }
cargo = { level = "warn", priority = -1 }
# Individual Clippy lints from the `restriction` category
arithmetic_side_effects = "warn"
as_conversions = "warn"
assertions_on_result_states = "warn"
clone_on_ref_ptr = "warn"
decimal_literal_representation = "warn"
default_union_representation = "warn"
deref_by_slicing = "warn"
disallowed_script_idents = "deny"
else_if_without_else = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exit = "deny"
expect_used = "warn"
float_cmp_const = "warn"
get_unwrap = "warn"
if_then_some_else_none = "warn"
impl_trait_in_params = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
let_underscore_must_use = "warn"
let_underscore_untyped = "warn"
lossy_float_literal = "warn"
mem_forget = "warn"
missing_inline_in_public_items = "warn"
multiple_inherent_impl = "warn"
multiple_unsafe_ops_per_block = "warn"
mutex_atomic = "warn"
non_zero_suggestions = "warn"
panic = "warn"
partial_pub_fields = "warn"
pattern_type_mismatch = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "warn"
redundant_type_annotations = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
same_name_method = "warn"
self_named_module_files = "deny"
semicolon_inside_block = "warn"
shadow_same = "warn"
shadow_unrelated = "warn"
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_to_string = "warn"
tests_outside_test_module = "warn"
todo = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unneeded_field_pattern = "warn"
unseparated_literal_suffix = "warn"
unused_result_ok = "warn"
unused_trait_names = "warn"
unwrap_used = "warn"
verbose_file_reads = "warn"

2
rust/clippy.toml Normal file
View File

@@ -0,0 +1,2 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -0,0 +1,77 @@
[package]
name = "exo_pyo3_bindings"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings"
# "cdylib" needed to produce shared library for Python to import
# "rlib" needed for stub-gen to run
crate-type = ["cdylib", "rlib"]
[[bin]]
path = "src/bin/stub_gen.rs"
name = "stub_gen"
doc = false
[lints]
workspace = true
[dependencies]
networking = { workspace = true }
# interop
pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.13.1" }
pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
pyo3-log = "0.12"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -0,0 +1 @@
TODO: do something here....

View File

@@ -0,0 +1,207 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401
import builtins
from enum import Enum
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes:bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes:bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes:bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes:bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n:builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes:bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string:builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
class NetworkingHandle:
def __new__(cls, identity:Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit:builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic:builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic:builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
"""
async def gossipsub_publish(self, topic:builtins.str, data:bytes) -> None:
r"""
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit:builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes:bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
class ConnectionUpdateType(Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...

View File

@@ -0,0 +1,32 @@
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[project]
name = "exo_pyo3_bindings"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
]
requires-python = ">=3.13"
dependencies = []
[dependency-groups]
dev = [
"exo_pyo3_bindings",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
]
[tool.maturin]
#purelib = true
#python-source = "python"
module-name = "exo_pyo3_bindings"
features = ["pyo3/extension-module", "pyo3/experimental-async"]
[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
asyncio_mode = "auto"

View File

@@ -0,0 +1,40 @@
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
pin::{Pin, pin},
task::{Context, Poll},
};
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
}
}

View File

@@ -0,0 +1,8 @@
use pyo3_stub_gen::Result;
fn main() -> Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = exo_pyo3_bindings::stub_info()?;
stub.generate()?;
Ok(())
}

View File

@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,217 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
#[ext(pub, name = ResultExt)]
impl<T, E> Result<T, E>
where
E: ToString,
{
fn pyerr(self) -> PyResult<T> {
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(())
}
define_stub_info_gatherer!(stub_info);

View File

@@ -0,0 +1,534 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
use networking::discovery;
use networking::swarm::create_swarm;
use util::ext::VecExt as _;
mod exception {
use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments};
use pyo3::types::PyTuple;
use pyo3_stub_gen::{derive::*};
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
pub struct PyNoPeersSubscribedToTopicError {}
impl PyNoPeersSubscribedToTopicError {
const MSG: &'static str = "\
No peers are currently subscribed to receive messages on this topic. \
Wait for peers to subscribe or check your network connectivity.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use networking::swarm::BehaviourEvent::*;
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
}
#[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.map(|(t, d)| (t, d.pybytes())))
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -0,0 +1,54 @@
#[cfg(test)]
mod tests {
use core::mem::drop;
use core::option::Option::Some;
use core::time::Duration;
use tokio;
use tokio::sync::mpsc;
#[tokio::test]
async fn test_drop_channel() {
struct Ping;
let (tx, mut rx) = mpsc::channel::<Ping>(10);
let _ = tokio::spawn(async move {
println!("TASK: entered");
loop {
tokio::select! {
result = rx.recv() => {
match result {
Some(_) => {
println!("TASK: pinged");
}
None => {
println!("TASK: closing channel");
break;
}
}
}
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
println!("TASK: heartbeat");
}
}
}
println!("TASK: exited");
});
let tx2 = tx.clone();
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx.send(Ping).await.expect("Should not fail");
drop(tx);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx2.send(Ping).await.expect("Should not fail");
drop(tx2);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
}
}

View File

@@ -0,0 +1,34 @@
import asyncio
import pytest
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
print("PYTHON: starting handle")
h = NetworkingHandle(Keypair.generate_ed25519())
ct = asyncio.create_task(_await_cons(h))
mt = asyncio.create_task(_await_msg(h))
# sleep for 4 ticks
for i in range(4):
await asyncio.sleep(1)
try:
await h.gossipsub_publish("topic", b"somehting or other")
except NoPeersSubscribedToTopicError as e:
print("caught it", e)
async def _await_cons(h: NetworkingHandle):
while True:
c = await h.connection_update_recv()
print(f"PYTHON: connection update: {c}")
async def _await_msg(h: NetworkingHandle):
while True:
m = await h.gossipsub_recv()
print(f"PYTHON: message: {m}")

View File

@@ -0,0 +1,44 @@
[package]
name = "networking"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "networking"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -0,0 +1,74 @@
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
#[tokio::main]
async fn main() {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
// Configure swarm
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
// Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
}
}
}
}

View File

@@ -0,0 +1,130 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::{
error::Error,
hash::{Hash},
};
use std::time::Duration;
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -0,0 +1,44 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS
----
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions
--------------------------------
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
---------------------------------
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c

View File

@@ -0,0 +1,379 @@
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
use crate::ext::MultiaddrExt;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use std::io;
use std::time::Duration;
use libp2p::{identity, mdns, ping};
use libp2p::swarm::NetworkBehaviour;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
ping: ping::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
mdns: mdns_behaviour(keypair)?,
ping: ping_behaviour(),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {
ttl: MDNS_RECORD_TTL,
query_interval: MDNS_QUERY_INTERVAL,
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL))
}
}
/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
ConnectionEstablished {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
ConnectionClosed {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
/// to the swarm.
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
/// immediately, and expired but connected peers are disconnected from immediately.
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
/// connected peers are disconnected from.
pub struct Behaviour {
// state-tracking for managed behaviors & mDNS-discovered peers
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
retry_delay: Delay, // retry interval
// pending events to emmit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
// get peer's multi-addresses or insert if missing
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
continue;
};
// multiaddress should never already be present - else something has gone wrong
let is_new_addr = mas.insert(ma);
assert!(is_new_addr, "cannot discover a discovered peer");
}
}
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
// at this point, we *must* have the peer
let mas = self
.mdns_discovered
.get_mut(&p)
.expect("nonexistent peer cannot expire");
// at this point, we *must* have the multiaddress
let was_present = mas.remove(&ma);
assert!(was_present, "nonexistent multiaddress cannot expire");
// if empty, remove the peer-id entirely
if mas.is_empty() {
self.mdns_discovered.remove(&p);
}
}
}
fn on_connection_established(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out connected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
fn on_connection_closed(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out disconnected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
}
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
delegate! {
to self.managed {
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
}
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)?,
))
}
#[allow(clippy::needless_question_mark)]
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
port_use: PortUse,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
port_use,
)?,
))
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => self.managed.on_connection_handler_event(
peer_id,
connection_id,
ev,
),
}
}
// hook into these methods to drive behavior
fn on_swarm_event(&mut self, event: FromSwarm) {
self.managed.on_swarm_event(event); // let mDNS handle swarm events
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
}
// since we are running TCP/IP transport layer, we are assuming that
// no address changes can occur, hence encountering one is a fatal error
FromSwarm::AddressChange(a) => {
unreachable!("unhandlable: address change encountered: {:?}", a)
}
_ => {}
}
}
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// delegate to managed behaviors for any behaviors they need to perform
match self.managed.poll(cx) {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
}
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
}
}
}
// since we just consumed an event, we should immediately wake just in case
// there are more events to come where that came from
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(
e.map_out(|_| unreachable!("events returning to swarm already handled"))
.map_in(Either::Right),
);
}
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}
// send out any pending events from our own service
if let Some(e) = self.pending_events.pop_front(cx) {
return Poll::Ready(e.map_in(Either::Left));
}
// wait for pending events
Poll::Pending
}
}

View File

@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -0,0 +1,64 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use std::net::IpAddr;
use extend::ext;
use libp2p::Multiaddr;
use libp2p::multiaddr::Protocol;
#[ext(pub, name = MultiaddrExt)]
impl Multiaddr {
/// If the multiaddress corresponds to a TCP address, extracts it
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
let mut ps = self.into_iter();
let ip = if let Some(p) = ps.next() {
match p {
Protocol::Ip4(ip) => IpAddr::V4(ip),
Protocol::Ip6(ip) => IpAddr::V6(ip),
_ => return None
}
} else {
return None;
};
let Some(Protocol::Tcp(port)) = ps.next() else {
return None;
};
Some((ip, port))
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -0,0 +1,133 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
pub type Swarm = libp2p::Swarm<Behaviour>;
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
}
mod transport {
use crate::alias;
use crate::swarm::NETWORK_VERSION;
use futures::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;
use libp2p::pnet::{PnetError, PnetOutput};
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
/// See [`pnet_upgrade`] for more.
const PNET_PRESHARED_KEY: [u8; 32] = Sha3_256::new()
.update(b"exo_discovery_network")
.update(NETWORK_VERSION)
.finalize();
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
/// also different-versioned instances of this same network.
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
async fn pnet_upgrade<TSocket>(
socket: TSocket,
_: impl Sized,
) -> Result<PnetOutput<TSocket>, PnetError>
where
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
use pnet::{PnetConfig, PreSharedKey};
PnetConfig::new(PreSharedKey::new(PNET_PRESHARED_KEY))
.handshake(socket)
.await
}
/// TCP/IP transport layer configuration.
pub fn tcp_transport(
keypair: &identity::Keypair,
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
use libp2p::{
core::upgrade::Version,
tcp::{Config, tokio},
};
// `TCP_NODELAY` enabled => avoid latency
let tcp_config = Config::default().nodelay(true);
// V1 + lazy flushing => 0-RTT negotiation
let upgrade_version = Version::V1Lazy;
// Noise is faster than TLS + we don't care much for security
let noise_config = noise::Config::new(keypair)?;
// Use default Yamux config for multiplexing
let yamux_config = yamux::Config::default();
// Create new Tokio-driven TCP/IP transport layer
let base_transport = tokio::Transport::new(tcp_config)
.and_then(pnet_upgrade)
.upgrade(upgrade_version)
.authenticate(noise_config)
.multiplex(yamux_config);
// Return boxed transport (to flatten complex type)
Ok(base_transport.boxed())
}
}
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
#[derive(NetworkBehaviour)]
pub struct Behaviour {
pub discovery: discovery::Behaviour,
pub gossipsub: gossipsub::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
gossipsub: gossipsub_behaviour(keypair),
})
}
}
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
// build a gossipsub network behaviour
// => signed message authenticity + strict validation mode means the message-ID is
// automatically provided by gossipsub w/out needing to provide custom message-ID function
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),
)
.expect("creating gossipsub behavior should always work")
}
}

View File

@@ -0,0 +1,7 @@
// maybe this will hold test in the future...??
#[cfg(test)]
mod tests {
#[test]
fn does_nothing() {}
}

2
rust/rust-toolchain.toml Normal file
View File

@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"

View File

@@ -0,0 +1,47 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -0,0 +1,4 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -0,0 +1,69 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

25
rust/util/Cargo.toml Normal file
View File

@@ -0,0 +1,25 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "util"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# macro dependencies
extend = { workspace = true }
# utility dependencies
thiserror = { workspace = true }
once_cell = { workspace = true }
internment = { workspace = true }
derive_more = { workspace = true }
bon = { workspace = true }
recursion = { workspace = true }

53
rust/util/src/lib.rs Normal file
View File

@@ -0,0 +1,53 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub mod nonempty;
pub mod wakerdeque;
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub mod ext {
use extend::ext;
#[ext(pub, name = BoxedSliceExt)]
impl<T> Box<[T]> {
#[inline]
fn map<B, F>(self, f: F) -> Box<[B]>
where
F: FnMut(T) -> B,
{
self.into_iter().map(f).collect()
}
}
#[ext(pub, name = VecExt)]
impl<T> Vec<T> {
#[inline]
fn map<B, F>(self, f: F) -> Vec<B>
where
F: FnMut(T) -> B,
{
self.into_iter().map(f).collect()
}
}
}

138
rust/util/src/nonempty.rs Normal file
View File

@@ -0,0 +1,138 @@
use std::slice::SliceIndex;
use std::{ops, slice};
use thiserror::Error;
#[derive(Error, Debug)]
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
pub struct EmptySliceError;
/// A pointer to a non-empty fixed-size slice allocated on the heap.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonemptyArray<T>(Box<[T]>);
#[allow(clippy::arbitrary_source_item_ordering)]
impl<T> NonemptyArray<T> {
#[inline]
pub fn singleton(value: T) -> Self {
Self(Box::new([value]))
}
#[allow(clippy::missing_errors_doc)]
#[inline]
pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
boxed_slice: S,
) -> Result<Self, EmptySliceError> {
let boxed_slice = boxed_slice.into();
if boxed_slice.is_empty() {
Err(EmptySliceError)
} else {
Ok(Self(boxed_slice))
}
}
#[must_use]
#[inline]
pub fn into_boxed_slice(self) -> Box<[T]> {
self.0
}
#[must_use]
#[inline]
pub fn to_vec(&self) -> Vec<T>
where
T: Clone,
{
self.0.to_vec()
}
#[must_use]
#[inline]
pub const fn as_slice(&self) -> &[T] {
&self.0
}
#[allow(clippy::indexing_slicing)]
#[must_use]
#[inline]
pub fn first(&self) -> &T {
&self.0[0]
}
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
#[must_use]
#[inline]
pub fn last(&self) -> &T {
&self.0[self.0.len() - 1]
}
#[must_use]
#[inline]
pub fn get<I>(&self, index: I) -> Option<&I::Output>
where
I: SliceIndex<[T]>,
{
self.0.get(index)
}
#[allow(clippy::len_without_is_empty)]
#[must_use]
#[inline]
pub const fn len(&self) -> usize {
self.0.len()
}
#[allow(clippy::iter_without_into_iter)]
#[inline]
pub fn iter(&self) -> slice::Iter<'_, T> {
self.0.iter()
}
#[allow(clippy::iter_without_into_iter)]
#[inline]
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
self.0.iter_mut()
}
#[inline]
#[must_use]
pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
NonemptyArray(self.0.into_iter().map(f).collect())
}
}
impl<T> From<NonemptyArray<T>> for Box<[T]> {
#[inline]
fn from(value: NonemptyArray<T>) -> Self {
value.into_boxed_slice()
}
}
impl<T> ops::Index<usize> for NonemptyArray<T> {
type Output = T;
#[inline]
fn index(&self, index: usize) -> &Self::Output {
self.0.index(index)
}
}
impl<T> IntoIterator for NonemptyArray<T> {
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.into_boxed_slice().into_vec().into_iter()
}
}
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
type Item = &'a T;
type IntoIter = slice::Iter<'a, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

View File

@@ -0,0 +1,55 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
waker: Option<Waker>,
deque: VecDeque<T>,
}
impl<T: Debug> Debug for WakerDeque<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.deque.fmt(f)
}
}
impl<T> WakerDeque<T> {
pub fn new() -> Self {
Self {
waker: None,
deque: VecDeque::new(),
}
}
fn update(&mut self, cx: &mut Context<'_>) {
self.waker = Some(cx.waker().clone());
}
fn wake(&mut self) {
let Some(ref mut w) = self.waker else { return };
w.wake_by_ref();
self.waker = None;
}
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_front()
}
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_back()
}
pub fn push_front(&mut self, value: T) {
self.wake();
self.deque.push_front(value);
}
pub fn push_back(&mut self, value: T) {
self.wake();
self.deque.push_back(value);
}
}

View File

@@ -8,9 +8,10 @@ import mlx.nn as nn # type: ignore
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
def __call__(self, x: mx.array, cache: Optional[list[KVCache]]) -> mx.array: ...
@@ -18,7 +19,7 @@ class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self) -> str: ...
@@ -27,5 +28,5 @@ class TokenizerWrapper:
bos_token: Optional[str]
eos_token_ids: list[int]
detokenizer: Detokenizer
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...

View File

@@ -29,6 +29,7 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
mlx_rank: None | int = None
mlx_world_size: None | int = None
def mx_barrier():
mx.eval( # type: ignore
mx.distributed.all_sum(
@@ -36,6 +37,7 @@ def mx_barrier():
)
)
def broadcast_from_zero(value: int) -> int:
if mlx_rank is None:
return value
@@ -46,8 +48,9 @@ def broadcast_from_zero(value: int) -> int:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu))
mx.eval(m) # type: ignore
return int(m.item()) # type: ignore
mx.eval(m) # type: ignore
return int(m.item()) # type: ignore
class HostList(RootModel[list[str]]):
@classmethod
@@ -83,7 +86,7 @@ def mlx_setup(
if wired_frac_of_mrwss > 0.0:
target_wired = int(wired_frac_of_mrwss * mrwss)
target_wired = min(target_wired, target_cache) # dont wire more than cache
runner_print(f"{target_wired=}")
with contextlib.suppress(Exception): # older macOS wont have this
mx.set_wired_limit(max(target_wired, 0))
@@ -136,14 +139,14 @@ def initialize_mlx(
def shard_and_load(
model_shard_meta: ShardMetadata,
model_shard_meta: ShardMetadata,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(model_shard_meta.model_meta.model_id)
runner_print(f"loading model from {model_path}")
model, config = load_model(model_path, lazy=True, strict=False) # type: ignore
runner_print(f'{config=}')
runner_print(f"{config=}")
assert isinstance(model, nn.Module)
tokenizer = load_tokenizer(model_path)
@@ -154,7 +157,7 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier()
return model, tokenizer # type: ignore
return model, tokenizer # type: ignore
async def apply_chat_template(
@@ -199,11 +202,13 @@ async def apply_chat_template(
return prompt
class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.
It satisfies .state/.meta_state and never allocates real keys/values.
"""
def __init__(self, dtype: mx.Dtype = mx.float16):
super().__init__()
# zero-length K/V so shapes/dtypes are defined but empty
@@ -218,19 +223,21 @@ class NullKVCache(KVCache):
@state.setter
def state(self, v: tuple[mx.array, mx.array]) -> None:
raise NotImplementedError('We should not be setting a NullKVCache.')
raise NotImplementedError("We should not be setting a NullKVCache.")
async def make_kv_cache(
model: Model,
max_kv_size: Optional[int] = None,
) -> list[KVCache]:
assert hasattr(model, 'layers')
assert hasattr(model, "layers")
return [
NullKVCache() if isinstance(layer, IdentityLayer) else KVCache()
for layer in model.layers
]
def mlx_force_oom(size: int = 40000) -> None:
"""
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.

View File

@@ -1,41 +1,221 @@
import argparse
import multiprocessing as mp
from dataclasses import dataclass
from typing import Self
import anyio
from anyio.abc import TaskGroup
from loguru import logger
from pydantic import PositiveInt
from exo.master.main import main as master_main
import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
from exo.worker.main import main as worker_main
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
# TODO: Entrypoint refactor
# I marked this as a dataclass as I want trivial constructors.
# This is the collection of systems for our entire application.
@dataclass
class Node:
router: Router
worker: Worker
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None
api: API | None
node_id: NodeId
_tg: TaskGroup | None = None
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
logger.info(f"Starting node {node_id}")
if args.spawn_api:
api = API(
node_id=node_id,
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
api = None
worker = Worker(
node_id,
exo_shard_downloader(),
initial_connection_messages=[],
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
# We start every node with a master
master = Master(
node_id,
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
tb_only=args.tb_only,
)
# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
er_send, er_recv = channel[ElectionResult]()
election = Election(
node_id,
seniority=1_000_000 if args.force_master else 0,
# nb: this DOES feedback right now. i have thoughts on how to address this,
# but ultimately it seems not worth the complexity
election_message_sender=router.sender(topics.ELECTION_MESSAGES),
election_message_receiver=router.receiver(topics.ELECTION_MESSAGES),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
async def run(self):
async with anyio.create_task_group() as tg:
self._tg = tg
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
async def _elect_loop(self):
assert self._tg
with self.election_result_receiver as results:
async for result in results:
# I don't like this duplication, but it's manageable for now.
# TODO: This function needs refactoring generally
# Ok:
# On new master:
# - Elect master locally if necessary
# - Shutdown and re-create the worker
# - Shut down and re-create the API
if result.node_id == self.node_id and self.master is not None:
logger.info("Node elected Master")
elif result.node_id == self.node_id and self.master is None:
logger.info("Node elected Master - promoting self")
self.master = Master(
self.node_id,
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
)
self._tg.start_soon(self.master.run)
elif result.node_id != self.node_id and self.master is not None:
logger.info(f"Node {result.node_id} elected master - demoting self")
await self.master.shutdown()
self.master = None
else:
logger.info(f"Node {result.node_id} elected master")
if result.is_new_master:
await anyio.sleep(0)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
exo_shard_downloader(),
initial_connection_messages=result.historic_messages,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset()
def main():
parser = argparse.ArgumentParser(prog="exo")
parser.add_argument(
"-v", "--verbose", action="store_const", const=1, dest="verbosity", default=0
)
parser.add_argument(
"-vv",
"--very-verbose",
action="store_const",
const=2,
dest="verbosity",
default=0,
)
args = parser.parse_args()
if type(args.verbosity) is not int: # type: ignore
raise TypeError("Verbosity was parsed incorrectly")
args = Args.parse()
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("starting exo")
logger.info("Starting EXO")
# This is for future PyInstaller compatibility
mp.set_start_method("spawn", force=True)
worker = mp.Process(target=worker_main, args=(EXO_LOG, args.verbosity))
master = mp.Process(target=master_main, args=(EXO_LOG, args.verbosity))
worker.start()
master.start()
worker.join()
master.join()
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger_cleanup()
class Args(CamelCaseModel):
verbosity: int = 0
force_master: bool = False
spawn_api: bool = False
api_port: PositiveInt = 8000
tb_only: bool = False
@classmethod
def parse(cls) -> Self:
parser = argparse.ArgumentParser(prog="EXO")
default_verbosity = 0
parser.add_argument(
"-q",
"--quiet",
action="store_const",
const=-1,
dest="verbosity",
default=default_verbosity,
)
parser.add_argument(
"-v",
"--verbose",
action="count",
dest="verbosity",
default=default_verbosity,
)
parser.add_argument(
"-m",
"--force-master",
action="store_true",
dest="force_master",
)
parser.add_argument(
"--no-api",
action="store_false",
dest="spawn_api",
)
parser.add_argument(
"--api-port",
type=int,
dest="api_port",
default=8000,
)
parser.add_argument(
"--tb-only",
action="store_true",
dest="tb_only",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -2,16 +2,18 @@ import asyncio
import os
import time
from collections.abc import AsyncGenerator
from typing import Callable, List, Sequence, final
from typing import final
import uvicorn
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
from exo.shared.apply import apply
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
@@ -24,23 +26,26 @@ from exo.shared.types.api import (
ModelListModel,
StreamingChoiceResponse,
)
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.events.chunks import TokenChunk
from exo.shared.types.events.commands import (
ChatCompletionCommand,
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CommandType,
CreateInstanceCommand,
DeleteInstanceCommand,
TaskFinishedCommand,
CreateInstance,
DeleteInstance,
ForwarderCommand,
TaggedCommand,
# TODO: SpinUpInstance
TaskFinished,
)
from exo.shared.types.events.components import EventFromEventLog
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.models import ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
@@ -70,26 +75,50 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata:
class API:
def __init__(
self,
command_buffer: List[Command],
global_events: AsyncSQLiteEventStorage,
get_state: Callable[[], State],
*,
node_id: NodeId,
port: int = 8000,
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
) -> None:
self.get_state = get_state
self.command_buffer = command_buffer
self.global_events = global_events
self.state = State()
self.command_sender = command_sender
self.global_event_receiver = global_event_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.port = port
self._app = FastAPI()
self.app = FastAPI()
self._setup_cors()
self._setup_routes()
self._app.mount(
self.app.mount(
"/",
StaticFiles(directory=os.environ["DASHBOARD_DIR"], html=True),
StaticFiles(
directory=os.environ.get(
"DASHBOARD_DIR",
os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../dashboard")
),
),
html=True,
),
name="dashboard",
)
self._chat_completion_queues: dict[
CommandId, asyncio.Queue[ChunkGenerated]
] = {}
self._tg: TaskGroup | None = None
def reset(self):
self.state = State()
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
def _setup_cors(self) -> None:
self._app.add_middleware(
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
@@ -98,23 +127,19 @@ class API:
)
def _setup_routes(self) -> None:
self._app.post("/instance")(self.create_instance)
self._app.get("/instance/{instance_id}")(self.get_instance)
self._app.delete("/instance/{instance_id}")(self.delete_instance)
self._app.get("/models")(self.get_models)
self._app.get("/v1/models")(self.get_models)
self._app.post("/v1/chat/completions")(self.chat_completions)
self._app.get("/state")(self.get_state)
@property
def app(self) -> FastAPI:
return self._app
self.app.post("/instance")(self.create_instance)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions")(self.chat_completions)
self.app.get("/state")(lambda: self.state)
async def create_instance(
self, payload: CreateInstanceTaskParams
) -> CreateInstanceResponse:
model_meta = await resolve_model_meta(payload.model_id)
required_memory_bytes = model_meta.storage_size_kilobytes * 1024
required_memory_bytes = model_meta.storage_size.in_kb
available_memory_bytes = self._calculate_total_available_memory()
if required_memory_bytes > available_memory_bytes:
@@ -123,37 +148,33 @@ class API:
detail=f"Insufficient memory to create instance. Required: {required_memory_bytes // (1024**3):.1f}GB, Available: {available_memory_bytes // (1024**3):.1f}GB",
)
command = CreateInstanceCommand(
command = CreateInstance(
command_id=CommandId(),
command_type=CommandType.CREATE_INSTANCE,
model_meta=model_meta,
instance_id=InstanceId(),
)
self.command_buffer.append(command)
await self._send(command)
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
instance_id=command.instance_id,
)
def get_instance(self, instance_id: InstanceId) -> Instance:
state = self.get_state()
state = self.state
if instance_id not in state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
return state.instances[instance_id]
def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
if instance_id not in self.get_state().instances:
async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = DeleteInstanceCommand(
command = DeleteInstance(
command_id=CommandId(),
command_type=CommandType.DELETE_INSTANCE,
instance_id=instance_id,
)
self.command_buffer.append(command)
await self._send(command)
return DeleteInstanceResponse(
message="Command received.",
command_id=command.command_id,
@@ -165,37 +186,27 @@ class API:
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
events = await self.global_events.get_events_since(0)
prev_idx = await self.global_events.get_last_idx()
self._chat_completion_queues[command_id] = asyncio.Queue()
finished = False
while not finished:
await asyncio.sleep(0.01)
# TODO: how long should this timeout be?
chunk = await asyncio.wait_for(
self._chat_completion_queues[command_id].get(), timeout=60
)
if chunk.command_id == command_id:
assert isinstance(chunk.chunk, TokenChunk)
chunk_response: ChatCompletionResponse = chunk_to_response(chunk.chunk)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
events: Sequence[
EventFromEventLog[Event]
] = await self.global_events.get_events_since(prev_idx)
# TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
prev_idx = events[-1].idx_in_log if events else prev_idx
if chunk.chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
finished = True
for wrapped_event in events:
event = wrapped_event.event
if isinstance(event, ChunkGenerated) and event.command_id == command_id:
assert isinstance(event.chunk, TokenChunk)
chunk_response: ChatCompletionResponse = chunk_to_response(
event.chunk
)
logger.debug(chunk_response)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if event.chunk.finish_reason is not None:
yield "data: [DONE]"
finished = True
command = TaskFinishedCommand(command_id=command_id)
self.command_buffer.append(command)
return
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
del self._chat_completion_queues[command_id]
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
@@ -210,6 +221,7 @@ class API:
payload.model = model_meta.model_id
# Preprocess messages for GPT-OSS harmony format if needed
# TODO: This is slop surely we get rid
if "gpt-oss" in payload.model.lower():
import re
@@ -233,7 +245,7 @@ class API:
# Store thinking in the thinking field
message.thinking = thinking_match.group(1).strip()
for instance in self.get_state().instances.values():
for instance in self.state.instances.values():
if instance.shard_assignments.model_id == payload.model:
break
else:
@@ -242,23 +254,22 @@ class API:
status_code=404, detail=f"No instance found for model {payload.model}"
)
command = ChatCompletionCommand(
command = ChatCompletion(
command_id=CommandId(),
command_type=CommandType.CHAT_COMPLETION,
request_params=payload,
)
self.command_buffer.append(command)
await self._send(command)
return StreamingResponse(
self._generate_chat_stream(command.command_id), media_type="text/plain"
)
def _calculate_total_available_memory(self) -> int:
"""Calculate total available memory across all nodes in bytes."""
state = self.get_state()
total_available = 0
for node_profile in state.node_profiles.values():
total_available += node_profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available.in_bytes
return total_available
@@ -277,14 +288,35 @@ class API:
]
)
async def run(self):
uvicorn_config = uvicorn.Config(
self.app, host="0.0.0.0", port=self.port, access_log=False
)
uvicorn_server = uvicorn.Server(uvicorn_config)
def start_fastapi_server(
command_buffer: List[Command],
global_events: AsyncSQLiteEventStorage,
get_state: Callable[[], State],
host: str = "0.0.0.0",
port: int = 8000,
):
api = API(command_buffer, global_events, get_state)
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(uvicorn_server.serve)
tg.start_soon(self._apply_state)
self.command_sender.close()
self.global_event_receiver.close()
uvicorn.run(api.app, host=host, port=port)
async def _apply_state(self):
with self.global_event_receiver as events:
async for event in events:
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
for idx, event in self.event_buffer.drain_indexed():
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
self._chat_completion_queues[event.command_id].put_nowait(event)
async def _send(self, command: Command):
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id, tagged_command=TaggedCommand.from_(command)
)
)

View File

@@ -1,23 +0,0 @@
from loguru import logger
from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
class ElectionCallbacks:
"""
Simple callbacks for the Rust election system to invoke.
No event system involvement - just direct forwarder control.
"""
def __init__(self, forwarder_supervisor: ForwarderSupervisor):
self._forwarder_supervisor = forwarder_supervisor
async def on_became_master(self) -> None:
"""Called when this node is elected as master"""
logger.info("Node elected as master")
await self._forwarder_supervisor.notify_role_change(ForwarderRole.MASTER)
async def on_became_replica(self) -> None:
"""Called when this node becomes a replica"""
logger.info("Node demoted to replica")
await self._forwarder_supervisor.notify_role_change(ForwarderRole.REPLICA)

View File

@@ -1,9 +0,0 @@
from pathlib import Path
from exo.shared.env import BaseEnv
class MasterEnvironmentSchema(BaseEnv):
# Master-specific: forwarder configuration
# Default to build/forwarder if not explicitly set
FORWARDER_BINARY_PATH: Path = Path("build/forwarder")

View File

@@ -10,7 +10,7 @@ from exo.shared.constants import (
EXO_GLOBAL_EVENT_DB,
EXO_WORKER_EVENT_DB,
LIBP2P_GLOBAL_EVENTS_TOPIC,
LIBP2P_WORKER_EVENTS_TOPIC,
LIBP2P_LOCAL_EVENTS_TOPIC,
)
from exo.shared.types.common import NodeId
@@ -58,9 +58,7 @@ class ForwarderSupervisor:
if self._current_role == new_role:
logger.debug(f"Role unchanged: {new_role}")
return
logger.bind(user_facing=True).info(
f"Node changing from {self._current_role} to {new_role}"
)
logger.info(f"Node changing from {self._current_role} to {new_role}")
self._current_role = new_role
await self._restart_with_role(new_role)
@@ -82,13 +80,13 @@ class ForwarderSupervisor:
# Both master and replica forward local worker events to network
pairs.append(
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}"
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}"
)
if role == ForwarderRole.MASTER:
# Master: collect worker events from network into global log
pairs.append(
f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events"
f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events"
)
# Master: broadcast global events to network
pairs.append(

View File

@@ -1,272 +1,240 @@
import asyncio
import os
import threading
from pathlib import Path
from anyio import create_task_group
from anyio.abc import TaskGroup
from loguru import logger
from exo.master.api import start_fastapi_server
from exo.master.election_callback import ElectionCallbacks
from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
from exo.master.placement import get_instance_placements, get_transition_events
from exo.master.placement import (
get_instance_placements_after_create,
get_instance_placements_after_delete,
get_transition_events,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_MASTER_LOG
from exo.shared.db.sqlite.config import EventLogConfig
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
from exo.shared.db.sqlite.event_log_manager import EventLogManager
from exo.shared.keypair import Keypair, get_node_id_keypair
from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
RequestEventLog,
SpinUpInstance,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
Heartbeat,
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
TaggedEvent,
TaskCreated,
TaskDeleted,
TopologyEdgeDeleted,
TopologyNodeCreated,
)
from exo.shared.types.events.commands import (
ChatCompletionCommand,
Command,
CreateInstanceCommand,
DeleteInstanceCommand,
TaskFinishedCommand,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from exo.shared.types.worker.instances import Instance
from exo.shared.types.worker.common import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
class Master:
def __init__(
self,
node_id_keypair: Keypair,
node_id: NodeId,
command_buffer: list[Command],
global_events: AsyncSQLiteEventStorage,
worker_events: AsyncSQLiteEventStorage,
forwarder_binary_path: Path,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
tb_only: bool = False,
):
self.state = State()
self.node_id_keypair = node_id_keypair
self._tg: TaskGroup | None = None
self.node_id = node_id
self.command_buffer = command_buffer
self.global_events = global_events
self.worker_events = worker_events
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.forwarder_supervisor = ForwarderSupervisor(
self.node_id,
forwarder_binary_path=forwarder_binary_path,
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self.election_callbacks = ElectionCallbacks(self.forwarder_supervisor)
@property
def event_log_for_reads(self) -> AsyncSQLiteEventStorage:
return self.global_events
@property
def event_log_for_writes(self) -> AsyncSQLiteEventStorage:
if self.forwarder_supervisor.current_role == ForwarderRole.MASTER:
return self.global_events
else:
return self.worker_events
async def _get_state_snapshot(self) -> State:
# TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
return State()
async def _run_event_loop_body(self) -> None:
next_events: list[Event] = []
# 1. process commands
if (
self.forwarder_supervisor.current_role == ForwarderRole.MASTER
and len(self.command_buffer) > 0
):
# for now we do one command at a time
next_command = self.command_buffer.pop(0)
logger.bind(user_facing=True).info(f"Executing command: {next_command}")
logger.info(f"Got command: {next_command}")
# TODO: validate the command
match next_command:
case ChatCompletionCommand():
matching_instance: Instance | None = None
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== next_command.request_params.model
):
matching_instance = instance
break
if not matching_instance:
raise ValueError(
f"No instance found for model {next_command.request_params.model}"
)
task_id = TaskId()
next_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_type=TaskType.CHAT_COMPLETION,
task_id=task_id,
command_id=next_command.command_id,
instance_id=matching_instance.instance_id,
task_status=TaskStatus.PENDING,
task_params=next_command.request_params,
),
)
)
self.command_task_mapping[next_command.command_id] = task_id
case DeleteInstanceCommand():
placement = get_instance_placements(
next_command, self.state.topology, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
next_events.extend(transition_events)
case CreateInstanceCommand():
placement = get_instance_placements(
next_command, self.state.topology, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
next_events.extend(transition_events)
case TaskFinishedCommand():
next_events.append(
TaskDeleted(
task_id=self.command_task_mapping[next_command.command_id]
)
)
del self.command_task_mapping[next_command.command_id]
await self.event_log_for_writes.append_events(
next_events, origin=self.node_id
)
# 2. get latest events
events = await self.event_log_for_reads.get_events_since(
self.state.last_event_applied_idx, ignore_no_op_events=True
)
if len(events) == 0:
await asyncio.sleep(0.01)
return
if len(events) == 1:
logger.debug(f"Master received event: {events[0]}")
else:
logger.debug(f"Master received events: {events}")
# 3. for each event, apply it to the state
for event_from_log in events:
logger.trace(f"Applying event: {event_from_log}")
self.state = apply(self.state, event_from_log)
logger.trace(f"State: {self.state.model_dump_json()}")
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
write_events: list[Event] = []
if any(
[
isinstance(event_from_log.event, TopologyEdgeDeleted)
for event_from_log in events
]
):
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
delete = False
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
delete = True
break
if delete:
write_events.append(InstanceDeleted(instance_id=instance_id))
if write_events:
await self.event_log_for_writes.append_events(
events=write_events, origin=self.node_id
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self.tb_only = tb_only
async def run(self):
self.state = await self._get_state_snapshot()
logger.info("Starting Master")
async def heartbeat_task():
while True:
await self.event_log_for_writes.append_events(
[Heartbeat(node_id=self.node_id)], origin=self.node_id
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
if self._tg:
logger.info("Stopping Master")
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(
f"Executing command: {forwarder_command.tagged_command.c}"
)
generated_events: list[Event] = []
command = forwarder_command.tagged_command.c
match command:
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
logger.warning(
f"No instance found for model {command.request_params.model}"
)
continue
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_type=TaskType.CHAT_COMPLETION,
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.PENDING,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = get_instance_placements_after_delete(
command, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = get_instance_placements_after_create(
command,
self.state.topology,
self.state.instances,
tb_only=self.tb_only,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case SpinUpInstance():
raise NotImplementedError
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except Exception as e:
logger.opt(exception=e).warning("Error in command processor")
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
async for local_event in local_events:
self._multi_buffer.ingest(
local_event.origin_idx,
local_event.tagged_event.c,
local_event.origin,
)
await asyncio.sleep(5)
for event in self._multi_buffer.drain():
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
# TODO: SQL
self._event_log.append(event)
await self._send_event(indexed)
asyncio.create_task(heartbeat_task())
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
if isinstance(event, TopologyEdgeDeleted):
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# TODO: we should clean these up on shutdown
await self.forwarder_supervisor.start_as_replica()
if os.getenv("EXO_RUN_AS_REPLICA") in set(["TRUE", "true", "1"]):
await self.election_callbacks.on_became_replica()
else:
await self.election_callbacks.on_became_master()
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
tagged_event=TaggedEvent.from_(event),
)
)
local_index += 1
role = (
"MASTER"
if self.forwarder_supervisor.current_role == ForwarderRole.MASTER
else "REPLICA"
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
tagged_event=TaggedEvent.from_(event.event),
)
)
await self.event_log_for_writes.append_events(
[TopologyNodeCreated(node_id=self.node_id, role=role)], origin=self.node_id
)
while True:
try:
await self._run_event_loop_body()
except Exception as e:
logger.opt(exception=e).error(f"Error in _run_event_loop_body: {e}")
await asyncio.sleep(0.1)
async def async_main():
node_id_keypair = get_node_id_keypair()
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
event_log_manager = EventLogManager(EventLogConfig())
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
worker_events: AsyncSQLiteEventStorage = event_log_manager.worker_events
command_buffer: list[Command] = []
logger.info("Starting EXO Master")
logger.info(f"Starting Master with node_id: {node_id}")
api_port = int(os.environ.get("API_PORT", 8000))
api_thread = threading.Thread(
target=start_fastapi_server,
args=(command_buffer, global_events, lambda: master.state, "0.0.0.0", api_port),
daemon=True,
)
api_thread.start()
logger.bind(user_facing=True).info(f"Dashboard started on port {api_port}.")
master = Master(
node_id_keypair,
node_id,
command_buffer,
global_events,
worker_events,
Path(os.environ["GO_BUILD_DIR"]) / "forwarder",
)
await master.run()
logger_cleanup() # pyright: ignore[reportUnreachable]
def main(logfile: Path = EXO_MASTER_LOG, verbosity: int = 1):
logger_setup(logfile, verbosity)
asyncio.run(async_main())
if __name__ == "__main__":
main()

View File

@@ -1,22 +1,22 @@
import random
from collections.abc import Mapping
from copy import deepcopy
from functools import singledispatch
from typing import Sequence
from exo.master.utils.placement_utils import (
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events.commands import (
CreateInstanceCommand,
DeleteInstanceCommand,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance, InstanceStatus
@@ -25,26 +25,24 @@ def random_ephemeral_port() -> int:
return random.randint(49152, 65535)
@singledispatch
def get_instance_placements(
command: CreateInstanceCommand,
def get_instance_placements_after_create(
command: CreateInstance,
topology: Topology,
current_instances: dict[InstanceId, Instance],
current_instances: Mapping[InstanceId, Instance],
*,
tb_only: bool = False,
) -> dict[InstanceId, Instance]:
available_models = [
current_instances[instance].shard_assignments.model_id
for instance in current_instances
]
if command.model_meta.model_id in available_models:
raise ValueError(f"Instance for {command.model_meta.model_id} already exists")
all_nodes = list(topology.list_nodes())
cycles = topology.get_cycles()
from loguru import logger
logger.info("finding cycles:")
cycles = topology.get_cycles_tb()
logger.info(f"{cycles=}")
# we can also always just have a node on its own
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = cycles + singleton_cycles
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size_kilobytes * 1024
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
@@ -52,25 +50,27 @@ def get_instance_placements(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
selected_cycle = None
has_thunderbolt_cycle = any(
[
topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
for cycle in smallest_cycles
]
)
if has_thunderbolt_cycle:
smallest_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if tb_only and smallest_tb_cycles == []:
raise ValueError("No cycles found with sufficient memory")
elif smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
selected_cycle = max(
smallest_cycles,
key=lambda cycle: sum(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
@@ -79,8 +79,8 @@ def get_instance_placements(
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
instance_id = command.instance_id
target_instances = deepcopy(current_instances)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
target_instances[instance_id] = Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
@@ -88,6 +88,9 @@ def get_instance_placements(
hosts=[
Host(
ip=host.ip,
# NOTE: this is stupid
# |
# v
# NOTE: it's fine to have non-deterministic ports here since this is in a command decision
port=random_ephemeral_port(),
)
@@ -97,13 +100,11 @@ def get_instance_placements(
return target_instances
@get_instance_placements.register
def _(
command: DeleteInstanceCommand,
topology: Topology,
current_instances: dict[InstanceId, Instance],
def get_instance_placements_after_delete(
command: DeleteInstance,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
target_instances = deepcopy(current_instances)
target_instances = dict(deepcopy(current_instances))
if command.instance_id in target_instances:
del target_instances[command.instance_id]
return target_instances

View File

@@ -4,9 +4,10 @@ from pydantic import BaseModel
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import Node
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.common import RunnerId
from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -17,38 +18,41 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]:
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[Node]], required_memory: int
) -> list[list[Node]]:
filtered_cycles: list[list[Node]] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
continue
total_mem = sum(node.node_profile.memory.ram_available for node in cycle)
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[Node], cycle))
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[Node],
selected_cycle: list[NodeInfo],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
cycle_memory = sum(
node.node_profile.memory.ram_available for node in selected_cycle
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
@@ -60,7 +64,11 @@ def get_shard_assignments(
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers * (node.node_profile.memory.ram_available / cycle_memory)
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
@@ -109,6 +117,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,

View File

@@ -1,3 +1,5 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
@@ -7,21 +9,21 @@ from exo.shared.types.profiling import (
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, Node
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> Node:
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return Node(
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile(
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
@@ -37,7 +39,7 @@ def create_node():
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection():
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
def _create_connection(
@@ -50,7 +52,6 @@ def create_connection():
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
send_back_multiaddr=Multiaddr(
address=f"/ip4/127.0.0.1/tcp/{send_back_port}"
),

View File

@@ -23,9 +23,8 @@ from exo.shared.constants import (
EXO_GLOBAL_EVENT_DB,
EXO_WORKER_EVENT_DB,
LIBP2P_GLOBAL_EVENTS_TOPIC,
LIBP2P_WORKER_EVENTS_TOPIC,
LIBP2P_LOCAL_EVENTS_TOPIC,
)
from exo.shared.logging import logger_test_install
from exo.shared.types.common import NodeId
# Mock forwarder script content
@@ -192,7 +191,6 @@ class TestForwardersupervisorBasic:
],
) -> None:
"""Test starting forwarder in replica mode."""
logger_test_install(test_logger)
# Set environment
os.environ.update(mock_env_vars)
@@ -216,7 +214,7 @@ class TestForwardersupervisorBasic:
# Expected replica forwarding pairs
expected_pairs = [
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}",
f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}",
f"libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
]
@@ -238,7 +236,6 @@ class TestForwardersupervisorBasic:
],
) -> None:
"""Test changing role from replica to master."""
logger_test_install(test_logger)
os.environ.update(mock_env_vars)
supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script)
@@ -265,7 +262,7 @@ class TestForwardersupervisorBasic:
# Expected master forwarding pairs
master_pairs = [
f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events",
f"sqlite:{EXO_GLOBAL_EVENT_DB}:events|libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}",
]
@@ -285,7 +282,6 @@ class TestForwardersupervisorBasic:
],
) -> None:
"""Test that setting the same role twice doesn't restart the process."""
logger_test_install(test_logger)
os.environ.update(mock_env_vars)
supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script)
@@ -316,7 +312,6 @@ class TestForwardersupervisorBasic:
],
) -> None:
"""Test that Forwardersupervisor restarts the process if it crashes."""
logger_test_install(test_logger)
# Configure mock to exit after 1 second
mock_env_vars["MOCK_EXIT_AFTER"] = "1"
mock_env_vars["MOCK_EXIT_CODE"] = "1"
@@ -365,7 +360,6 @@ class TestForwardersupervisorBasic:
self, test_logger: logging.Logger, temp_dir: Path
) -> None:
"""Test behavior when forwarder binary doesn't exist."""
logger_test_install(test_logger)
nonexistent_path = temp_dir / "nonexistent_forwarder"
supervisor = ForwarderSupervisor(NodeId(), nonexistent_path)
@@ -381,7 +375,6 @@ class TestElectionCallbacks:
@pytest.mark.asyncio
async def test_on_became_master(self, test_logger: logging.Logger) -> None:
"""Test callback when becoming master."""
logger_test_install(test_logger)
mock_supervisor = MagicMock(spec=ForwarderSupervisor)
mock_supervisor.notify_role_change = AsyncMock()
@@ -393,7 +386,6 @@ class TestElectionCallbacks:
@pytest.mark.asyncio
async def test_on_became_replica(self, test_logger: logging.Logger) -> None:
"""Test callback when becoming replica."""
logger_test_install(test_logger)
mock_supervisor = MagicMock(spec=ForwarderSupervisor)
mock_supervisor.notify_role_change = AsyncMock()

View File

@@ -1,30 +1,29 @@
import asyncio
import tempfile
from logging import Logger
from pathlib import Path
from typing import List, Sequence
import pytest
from exo.master.main import Master
from exo.shared.db.sqlite.config import EventLogConfig
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
from exo.shared.db.sqlite.event_log_manager import EventLogManager
from exo.shared.db.config import EventLogConfig
from exo.shared.db.connector import AsyncSQLiteEventStorage
from exo.shared.db.event_log_manager import EventLogManager
from exo.shared.keypair import Keypair
from exo.shared.logging import logger_test_install
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, EventFromEventLog, Heartbeat, TaskCreated
from exo.shared.types.events._events import (
InstanceCreated,
NodePerformanceMeasured,
TopologyNodeCreated,
)
from exo.shared.types.events.commands import (
ChatCompletionCommand,
from exo.shared.types.commands import (
ChatCompletion,
Command,
CommandId,
CreateInstanceCommand,
CreateInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
TaskCreated,
TopologyNodeCreated,
)
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import (
@@ -35,7 +34,6 @@ from exo.shared.types.profiling import (
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceStatus,
ShardAssignments,
)
@@ -43,7 +41,7 @@ from exo.shared.types.worker.shards import PartitionStrategy, PipelineShardMetad
def _create_forwarder_dummy_binary() -> Path:
path = Path(tempfile.mktemp()) / "forwarder.bin"
path = Path(tempfile.mkstemp()[1]) / "forwarder.bin"
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
@@ -53,23 +51,20 @@ def _create_forwarder_dummy_binary() -> Path:
@pytest.mark.asyncio
async def test_master():
logger = Logger(name="test_master_logger")
logger_test_install(logger)
event_log_manager = EventLogManager(EventLogConfig())
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
await global_events.delete_all_events()
async def _get_events() -> Sequence[EventFromEventLog[Event]]:
async def _get_events() -> Sequence[IndexedEvent]:
orig_events = await global_events.get_events_since(0)
override_idx_in_log = 1
events: List[EventFromEventLog[Event]] = []
events: List[IndexedEvent] = []
for e in orig_events:
if isinstance(e.event, Heartbeat):
continue
events.append(
EventFromEventLog(
event=e.event, origin=e.origin, idx_in_log=override_idx_in_log
IndexedEvent(
event=e.event,
idx=override_idx_in_log, # origin=e.origin,
)
)
override_idx_in_log += 1
@@ -120,9 +115,8 @@ async def test_master():
await asyncio.sleep(0.001)
command_buffer.append(
CreateInstanceCommand(
CreateInstance(
command_id=CommandId(),
instance_id=InstanceId(),
model_meta=ModelMetadata(
model_id="llama-3.2-1b",
pretty_name="Llama 3.2 1B",
@@ -134,7 +128,7 @@ async def test_master():
while len(master.state.instances.keys()) == 0:
await asyncio.sleep(0.001)
command_buffer.append(
ChatCompletionCommand(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
@@ -150,7 +144,7 @@ async def test_master():
events = await _get_events()
print(events)
assert len(events) == 4
assert events[0].idx_in_log == 1
assert events[0].idx == 1
assert isinstance(events[0].event, TopologyNodeCreated)
assert isinstance(events[1].event, NodePerformanceMeasured)
assert isinstance(events[2].event, InstanceCreated)

View File

@@ -2,15 +2,17 @@ from typing import Callable
import pytest
from exo.master.placement import get_instance_placements, get_transition_events
from exo.shared.topology import Topology
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events._events import (
_EventType, # pyright: ignore[reportPrivateUsage]
from exo.master.placement import (
get_instance_placements_after_create,
get_transition_events,
)
from exo.shared.types.events.commands import CreateInstanceCommand
from exo.shared.types.models import ModelMetadata
from exo.shared.types.topology import Connection, Node
from exo.shared.topology import Topology
from exo.shared.types.commands import CreateInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance, InstanceStatus
from exo.shared.types.worker.runners import ShardAssignments
@@ -27,7 +29,7 @@ def instance() -> Instance:
instance_id=InstanceId(),
instance_type=InstanceStatus.ACTIVE,
shard_assignments=ShardAssignments(
model_id="test-model", runner_to_shard={}, node_to_runner={}
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts=[],
)
@@ -36,18 +38,17 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id="test-model",
storage_size_kilobytes=1000,
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
)
def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand:
return CreateInstanceCommand(
def create_instance_command(model_meta: ModelMetadata) -> CreateInstance:
return CreateInstance(
command_id=CommandId(),
model_meta=model_meta,
instance_id=InstanceId(),
)
@@ -65,32 +66,33 @@ def test_get_instance_placements_create_instance(
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[Memory, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size_kilobytes = sum(
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
create_instance_command = CreateInstanceCommand(
create_instance_command = CreateInstance(
command_id=CommandId(),
model_meta=model_meta,
instance_id=InstanceId(),
)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0] * 1024, node_id_a))
topology.add_node(create_node(available_memory[1] * 1024, node_id_b))
topology.add_node(create_node(available_memory[2] * 1024, node_id_c))
topology.add_node(create_node(Memory.from_bytes(available_memory[0]), node_id_a))
topology.add_node(create_node(Memory.from_bytes(available_memory[1]), node_id_b))
topology.add_node(create_node(Memory.from_bytes(available_memory[2]), node_id_c))
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
# act
placements = get_instance_placements(create_instance_command, topology, {})
placements = get_instance_placements_after_create(
create_instance_command, topology, {}
)
# assert
assert len(placements) == 1
@@ -117,22 +119,23 @@ def test_get_instance_placements_create_instance(
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
create_instance_command = CreateInstanceCommand(
create_instance_command = CreateInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_id="test-model",
storage_size_kilobytes=1000,
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
),
instance_id=InstanceId(),
)
placements = get_instance_placements(create_instance_command, topology, {})
placements = get_instance_placements_after_create(
create_instance_command, topology, {}
)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -144,22 +147,23 @@ def test_get_instance_placements_one_node_exact_fit(
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
create_instance_command = CreateInstanceCommand(
create_instance_command = CreateInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_id="test-model",
storage_size_kilobytes=1000,
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
),
instance_id=InstanceId(),
)
placements = get_instance_placements(create_instance_command, topology, {})
placements = get_instance_placements_after_create(
create_instance_command, topology, {}
)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -171,27 +175,26 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
create_instance_command = CreateInstanceCommand(
create_instance_command = CreateInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_id="test-model",
storage_size_kilobytes=1001,
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
),
instance_id=InstanceId(),
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
get_instance_placements(create_instance_command, topology, {})
get_instance_placements_after_create(create_instance_command, topology, {})
def test_get_transition_events_no_change(topology: Topology, instance: Instance):
def test_get_transition_events_no_change(instance: Instance):
# arrange
instance_id = InstanceId()
current_instances = {instance_id: instance}
@@ -204,7 +207,7 @@ def test_get_transition_events_no_change(topology: Topology, instance: Instance)
assert len(events) == 0
def test_get_transition_events_create_instance(topology: Topology, instance: Instance):
def test_get_transition_events_create_instance(instance: Instance):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {}
@@ -215,10 +218,10 @@ def test_get_transition_events_create_instance(topology: Topology, instance: Ins
# assert
assert len(events) == 1
assert events[0].event_type == _EventType.InstanceCreated
assert isinstance(events[0], InstanceCreated)
def test_get_transition_events_delete_instance(topology: Topology, instance: Instance):
def test_get_transition_events_delete_instance(instance: Instance):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
@@ -229,5 +232,5 @@ def test_get_transition_events_delete_instance(topology: Topology, instance: Ins
# assert
assert len(events) == 1
assert events[0].event_type == _EventType.InstanceDeleted
assert isinstance(events[0], InstanceDeleted)
assert events[0].instance_id == instance_id

View File

@@ -1,9 +1,8 @@
from ipaddress import IPv4Address
from typing import Callable
import pytest
from exo.master.utils.placement_utils import (
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_shard_assignments,
@@ -11,8 +10,9 @@ from exo.master.utils.placement_utils import (
)
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.topology import Connection, Node
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import Connection, NodeInfo
@pytest.fixture
@@ -23,7 +23,7 @@ def topology() -> Topology:
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
@@ -47,7 +47,7 @@ def test_filter_cycles_by_memory(
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, 1)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
@@ -57,7 +57,7 @@ def test_filter_cycles_by_memory(
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
@@ -77,7 +77,9 @@ def test_filter_cycles_by_insufficient_memory(
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001 * 1024)
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
@@ -85,7 +87,7 @@ def test_filter_cycles_by_insufficient_memory(
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
@@ -110,7 +112,7 @@ def test_filter_multiple_cycles_by_memory(
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, 1500 * 1024)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
@@ -124,7 +126,7 @@ def test_filter_multiple_cycles_by_memory(
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
@@ -164,7 +166,7 @@ def test_get_smallest_cycles(
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
@@ -189,10 +191,10 @@ def test_get_shard_assignments(
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id="test-model",
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size_kilobytes=1000,
storage_size=Memory.from_kb(1000),
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
@@ -223,7 +225,7 @@ def test_get_shard_assignments(
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], Node],
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
@@ -250,9 +252,9 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=IPv4Address("127.0.0.1"), port=5001),
Host(ip=IPv4Address("127.0.0.1"), port=5002),
Host(ip=IPv4Address("127.0.0.1"), port=5003),
Host(ip=("127.0.0.1"), port=5001),
Host(ip=("127.0.0.1"), port=5002),
Host(ip=("127.0.0.1"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts

View File

@@ -7,7 +7,7 @@ from exo.shared.types.profiling import (
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, Node, NodeId
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -20,7 +20,6 @@ def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
@@ -30,7 +29,7 @@ def connection() -> Connection:
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile(flops_fp16=1000)
@@ -54,7 +53,7 @@ def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
node_id = NodeId()
# act
topology.add_node(Node(node_id=node_id, node_profile=node_profile))
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
data = topology.get_node_profile(node_id)
@@ -65,9 +64,11 @@ def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -82,9 +83,11 @@ def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -92,7 +95,7 @@ def test_update_node_profile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile(
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
@@ -113,9 +116,11 @@ def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -125,7 +130,6 @@ def test_update_connection_profile(
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
local_multiaddr=connection.local_multiaddr,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
@@ -142,9 +146,11 @@ def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -155,64 +161,15 @@ def test_remove_connection_still_connected(
assert topology.get_connection_profile(connection) is None
def test_remove_connection_bridge(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
"""Create a bridge scenario: master -> node_a -> node_b
and remove the bridge connection (master -> node_a)"""
# arrange
master_id = NodeId()
node_a_id = NodeId()
node_b_id = NodeId()
topology.add_node(Node(node_id=master_id, node_profile=node_profile))
topology.add_node(Node(node_id=node_a_id, node_profile=node_profile))
topology.add_node(Node(node_id=node_b_id, node_profile=node_profile))
topology.set_master_node_id(master_id)
connection_master_to_a = Connection(
local_node_id=master_id,
send_back_node_id=node_a_id,
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
connection_a_to_b = Connection(
local_node_id=node_a_id,
send_back_node_id=node_b_id,
local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1236"),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1237"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
topology.add_connection(connection_master_to_a)
topology.add_connection(connection_a_to_b)
assert len(list(topology.list_nodes())) == 3
topology.remove_connection(connection_master_to_a)
remaining_nodes = list(topology.list_nodes())
assert len(remaining_nodes) == 1
assert remaining_nodes[0].node_id == master_id
assert topology.get_node_profile(node_a_id) is None
assert topology.get_node_profile(node_b_id) is None
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -227,9 +184,11 @@ def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile))
topology.add_node(
Node(node_id=connection.send_back_node_id, node_profile=node_profile)
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
@@ -238,7 +197,7 @@ def test_list_nodes(
# assert
assert len(nodes) == 2
assert all(isinstance(node, Node) for node in nodes)
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,

View File

View File

@@ -0,0 +1,37 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)

242
src/exo/routing/router.py Normal file
View File

@@ -0,0 +1,242 @@
from copy import copy
from itertools import count
from math import inf
from os import PathLike
from pathlib import Path
from typing import cast
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
sleep_forever,
)
from anyio.abc import TaskGroup
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
from filelock import FileLock
from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
from .connection_message import ConnectionMessage
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
# A significant current limitation of the TopicRouter is that it is not capable
# of preventing feedback, as it does not ask for a system id so cannot tell
# which message is coming/going to which system.
# This is currently only relevant for Election
class TopicRouter[T: CamelCaseModel]:
def __init__(
self,
topic: TypedTopic[T],
networking_sender: Sender[tuple[str, bytes]],
max_buffer_size: float = inf,
):
self.topic: TypedTopic[T] = topic
self.senders: set[Sender[T]] = set()
send, recv = channel[T]()
self.receiver: Receiver[T] = recv
self.temp_sender: Sender[T] | None = send
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender
async def run(self):
logger.debug(f"Topic Router {self.topic} ready to send")
with self.receiver as items:
async for item in items:
# Check if we should send to network
if (
len(self.senders) == 0
and self.topic.publish_policy is PublishPolicy.Minimal
):
await self._send_out(item)
continue
if self.topic.publish_policy is PublishPolicy.Always:
await self._send_out(item)
# Then publish to all senders
await self.publish(item)
async def shutdown(self):
logger.debug(f"Shutting down Topic Router {self.topic}")
# Close all the things!
for sender in self.senders:
sender.close()
if self.temp_sender:
self.temp_sender.close()
self.receiver.close()
async def publish(self, item: T):
"""
Publish item T on this topic to all senders.
NB: this sends to ALL receivers, potentially including receivers held by the object doing the sending.
You should handle your own output if you hold a sender + receiver pair.
"""
to_clear: set[Sender[T]] = set()
for sender in copy(self.senders):
try:
await sender.send(item)
except (ClosedResourceError, BrokenResourceError):
to_clear.add(sender)
self.senders -= to_clear
async def publish_bytes(self, data: bytes):
await self.publish(self.topic.deserialize(data))
async def _send_out(self, item: T):
logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
await self.networking_sender.send(
(str(self.topic.topic), self.topic.serialize(item))
)
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity))
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender
if send:
self._tmp_networking_sender = None
else:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
router = self.topic_routers.get(topic.topic, None)
# There's gotta be a way to do this without THIS many asserts
assert router is not None
assert router.topic == topic
send: Sender[T] | None = cast(Sender[T] | None, router.temp_sender)
if send:
router.temp_sender = None
return send
try:
sender = cast(Receiver[T], router.receiver).clone_sender()
except ClosedResourceError:
sender, router.receiver = cast(
tuple[Sender[T], Receiver[CamelCaseModel]], channel[T]()
)
return sender
def receiver[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Receiver[T]:
router = self.topic_routers.get(topic.topic, None)
# There's gotta be a way to do this without THIS many asserts
assert router is not None
assert router.topic == topic
assert router.topic.model_type == topic.model_type
send, recv = channel[T]()
router.senders.add(cast(Sender[CamelCaseModel], send))
return recv
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
if not self._tg:
return
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
async def _networking_publish(self):
# This with/for pattern ensures this method doesn't return until after the receiver closes
# This is good for safety, but is mostly a redundant check.
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
except NoPeersSubscribedToTopicError:
logger.trace(f"Failed to send over {topic} - No peers found.")
def get_node_id_keypair(
path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR,
) -> Keypair:
"""
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
# operate with cross-process lock to avoid race conditions
with FileLock(lock_path(path)):
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
# if non-zero EOF, then file exists => use to get node-ID
if f.tell() != 0:
f.seek(0) # go to start & read protobuf-encoded bytes
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -0,0 +1,141 @@
import pytest
from exo.shared.types.events import Event, TestEvent
from exo.utils.event_buffer import OrderedBuffer
def make_indexed_event(idx: int) -> tuple[int, Event]:
"""Factory function to create a unique ForwarderEvent for a given index."""
return (idx, TestEvent())
@pytest.fixture
def buffer() -> OrderedBuffer[Event]:
"""Provides a clean instance of OrderedBuffer[Event] for each test."""
return OrderedBuffer[Event]()
@pytest.mark.asyncio
async def test_initial_state(buffer: OrderedBuffer[Event]):
"""Tests that a new buffer is empty and starts at index 1."""
assert buffer.next_idx_to_release == 0
assert not buffer.store
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
"""Tests ingesting and draining a simple, ordered sequence of events."""
events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)]
[buffer.ingest(*ev) for ev in events]
drained_events = buffer.drain_indexed()
assert drained_events == events
assert buffer.next_idx_to_release == 3
assert not buffer.store
@pytest.mark.asyncio
async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
"""Tests that out-of-order events are buffered and drained in the correct sequence."""
event1 = make_indexed_event(0)
event2 = make_indexed_event(1)
event3 = make_indexed_event(2)
buffer.ingest(*event3)
buffer.ingest(*event1)
buffer.ingest(*event2)
drained_events = buffer.drain_indexed()
assert drained_events == [event1, event2, event3]
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
"""Tests that draining stops when there is a gap in the event indices."""
event1 = make_indexed_event(0)
event3 = make_indexed_event(2)
buffer.ingest(*event1)
buffer.ingest(*event3)
drained_events = buffer.drain_indexed()
assert drained_events == [event1]
assert buffer.next_idx_to_release == 1
assert buffer.drain() == []
assert 2 in buffer.store
@pytest.mark.asyncio
async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
"""Tests that once a gap is filled, the rest of the sequence is drained."""
event0 = make_indexed_event(0)
event2 = make_indexed_event(2)
buffer.ingest(*event0)
buffer.ingest(*event2)
buffer.drain()
assert buffer.next_idx_to_release == 1
event1 = make_indexed_event(1)
buffer.ingest(*event1)
drained_events = buffer.drain_indexed()
assert [e[0] for e in drained_events] == [1, 2]
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
"""Tests that if multiple events for the same index are ingested, the first one wins."""
event2_first = make_indexed_event(1)
event2_second = (1, TestEvent())
buffer.ingest(*make_indexed_event(0))
buffer.ingest(*event2_first)
buffer.ingest(*event2_second) # This duplicate should be ignored
drained = buffer.drain_indexed()
assert len(drained) == 2
assert drained[1][1].event_id == event2_first[1].event_id
assert drained[1][1].event_id != event2_second[1].event_id
@pytest.mark.asyncio
async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
"""Tests that events with an index lower than next_idx_to_release are dropped."""
buffer.ingest(*make_indexed_event(0))
buffer.ingest(*make_indexed_event(1))
buffer.drain()
assert buffer.next_idx_to_release == 2
stale_event1 = make_indexed_event(0)
stale_event2 = make_indexed_event(1)
buffer.ingest(*stale_event1)
buffer.ingest(*stale_event2)
assert not buffer.store
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
"""Tests reusing the buffer after it has been fully drained."""
buffer.ingest(*make_indexed_event(0))
buffer.ingest(*make_indexed_event(1))
buffer.drain()
assert buffer.next_idx_to_release == 2
assert not buffer.store
buffer.ingest(*make_indexed_event(4))
buffer.ingest(*make_indexed_event(2))
drained = buffer.drain_indexed()
assert [e[0] for e in drained] == [2]
assert buffer.next_idx_to_release == 3
assert 4 in buffer.store

47
src/exo/routing/topics.py Normal file
View File

@@ -0,0 +1,47 @@
from dataclasses import dataclass
from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.utils.pydantic_ext import CamelCaseModel
class PublishPolicy(str, Enum):
Never = "Never"
"""Never publish to the network - this is a local message"""
Minimal = "Minimal"
"""Only publish when there is no local receiver for this type of message"""
Always = "Always"
"""Always publish to the network"""
@dataclass # (frozen=True)
class TypedTopic[T: CamelCaseModel]:
topic: str
publish_policy: PublishPolicy
model_type: type[
T
] # This can be worked around with evil type hacking, see https://stackoverflow.com/a/71720366 - I don't think it's necessary here.
@staticmethod
def serialize(t: T) -> bytes:
return t.model_dump_json().encode("utf-8")
def deserialize(self, b: bytes) -> T:
return self.model_type.model_validate_json(b.decode("utf-8"))
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
ELECTION_MESSAGES = TypedTopic(
"election_messages", PublishPolicy.Always, ElectionMessage
)
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)

View File

@@ -1,18 +1,17 @@
from __future__ import annotations
import copy
from functools import singledispatch
from typing import Mapping
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
EventFromEventLog,
IndexedEvent,
InstanceActivated,
InstanceCreated,
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
NodePerformanceMeasured,
RunnerDeleted,
RunnerStatusUpdated,
@@ -20,48 +19,74 @@ from exo.shared.types.events import (
TaskDeleted,
TaskFailed,
TaskStateUpdated,
TestEvent,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TopologyEdgeReplacedAtomically,
TopologyNodeCreated,
WorkerStatusUpdated,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, Node
from exo.shared.types.worker.common import NodeStatus, RunnerId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.common import RunnerId, WorkerStatus
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceStatus
from exo.shared.types.worker.runners import RunnerStatus
@singledispatch
def event_apply(event: Event, state: State) -> State:
"""Apply an event to *state*.
Events decorated with ``@no_op_event`` set ``__no_apply__ = True`` on the
class. Such events are considered *no-ops* and therefore leave the state
unchanged without requiring a dedicated handler in this dispatch table.
"""
if getattr(event, "__no_apply__", False):
return state
raise RuntimeError(f"no handler registered for event type {type(event).__name__}")
"""Apply an event to state."""
match event:
case TestEvent() | ChunkGenerated():
return state
case InstanceActivated():
return apply_instance_activated(event, state)
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeactivated():
return apply_instance_deactivated(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
return apply_runner_status_updated(event, state)
case TaskCreated():
return apply_task_created(event, state)
case TaskDeleted():
return apply_task_deleted(event, state)
case TaskFailed():
return apply_task_failed(event, state)
case TaskStateUpdated():
return apply_task_state_updated(event, state)
case WorkerStatusUpdated():
return apply_worker_status_updated(event, state)
case TopologyNodeCreated():
return apply_topology_node_created(event, state)
case TopologyEdgeCreated():
return apply_topology_edge_created(event, state)
case TopologyEdgeDeleted():
return apply_topology_edge_deleted(event, state)
def apply(state: State, event: EventFromEventLog[Event]) -> State:
def apply(state: State, event: IndexedEvent) -> State:
# Just to test that events are only applied in correct order
if state.last_event_applied_idx != event.idx - 1:
logger.warning(
f"Expected event {state.last_event_applied_idx + 1} but received {event.idx}"
)
assert state.last_event_applied_idx == event.idx - 1
new_state: State = event_apply(event.event, state)
return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})
return new_state.model_copy(update={"last_event_applied_idx": event.idx})
@event_apply.register(TaskCreated)
def apply_task_created(event: TaskCreated, state: State) -> State:
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register(TaskDeleted)
def apply_task_deleted(event: TaskDeleted, state: State) -> State:
new_tasks: Mapping[TaskId, Task] = {
tid: task for tid, task in state.tasks.items() if tid != event.task_id
@@ -69,7 +94,6 @@ def apply_task_deleted(event: TaskDeleted, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register(TaskStateUpdated)
def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
if event.task_id not in state.tasks:
return state
@@ -86,7 +110,6 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register(TaskFailed)
def apply_task_failed(event: TaskFailed, state: State) -> State:
if event.task_id not in state.tasks:
return state
@@ -98,7 +121,6 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register(InstanceCreated)
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
@@ -108,13 +130,12 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceActivated)
def apply_instance_activated(event: InstanceActivated, state: State) -> State:
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(
update={"type": InstanceStatus.ACTIVE}
update={"instance_type": InstanceStatus.ACTIVE}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
@@ -123,13 +144,12 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceDeactivated)
def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> State:
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(
update={"type": InstanceStatus.INACTIVE}
update={"instance_type": InstanceStatus.INACTIVE}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
@@ -138,7 +158,6 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceDeleted)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
@@ -146,19 +165,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
@event_apply.register(InstanceReplacedAtomically)
def apply_instance_replaced_atomically(
event: InstanceReplacedAtomically, state: State
) -> State:
new_instances = dict(state.instances)
if event.instance_to_replace in new_instances:
del new_instances[event.instance_to_replace]
if event.new_instance_id in state.instances:
new_instances[event.new_instance_id] = state.instances[event.new_instance_id]
return state.model_copy(update={"instances": new_instances})
@event_apply.register(RunnerStatusUpdated)
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,
@@ -167,7 +173,6 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta
return state.model_copy(update={"runners": new_runners})
@event_apply.register(RunnerDeleted)
def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid != event.runner_id
@@ -175,7 +180,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
return state.model_copy(update={"runners": new_runners})
@event_apply.register(NodePerformanceMeasured)
# TODO: This whole function needs fixing
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
@@ -187,58 +192,39 @@ def apply_node_performance_measured(
topology = copy.copy(state.topology)
if not topology.contains_node(event.node_id):
# TODO: figure out why this is happening in the first place
topology.add_node(Node(node_id=event.node_id))
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(update={"topology": topology})
@event_apply.register(WorkerStatusUpdated)
def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> State:
new_node_status: Mapping[NodeId, NodeStatus] = {
new_node_status: Mapping[NodeId, WorkerStatus] = {
**state.node_status,
event.node_id: event.node_state,
}
return state.model_copy(update={"node_status": new_node_status})
@event_apply.register(TopologyNodeCreated)
def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(Node(node_id=event.node_id))
if event.role == "MASTER":
topology.set_master_node_id(event.node_id)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
@event_apply.register(TopologyEdgeCreated)
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
@event_apply.register(TopologyEdgeReplacedAtomically)
def apply_topology_edge_replaced_atomically(
event: TopologyEdgeReplacedAtomically, state: State
) -> State:
topology = copy.copy(state.topology)
topology.update_connection_profile(event.edge)
return state.model_copy(update={"topology": topology})
@event_apply.register(TopologyEdgeDeleted)
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
opposite_edge = Connection(
local_node_id=event.edge.send_back_node_id,
send_back_node_id=event.edge.local_node_id,
local_multiaddr=event.edge.send_back_multiaddr,
send_back_multiaddr=event.edge.local_multiaddr,
)
if not topology.contains_connection(opposite_edge):
return state.model_copy(update={"topology": topology})
topology.remove_connection(opposite_edge)
if not topology.contains_connection(event.edge) and topology.contains_connection(
event.edge.reverse()
):
topology.remove_connection(event.edge.reverse())
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -1,3 +0,0 @@
from .apply import apply
__all__ = ["apply"]

View File

@@ -21,8 +21,10 @@ EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
EXO_IPC_DIR = EXO_HOME / "ipc"
# libp2p topics for event forwarding
LIBP2P_WORKER_EVENTS_TOPIC = "worker_events"
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3

View File

@@ -1,5 +0,0 @@
"""Database implementations for event storage."""
from .sqlite import AsyncSQLiteEventStorage, EventStorageProtocol
__all__ = ["AsyncSQLiteEventStorage", "EventStorageProtocol"]

View File

@@ -0,0 +1,19 @@
from pathlib import Path
from pydantic import BaseModel
from exo.shared.constants import EXO_GLOBAL_EVENT_DB
class EventLogConfig(BaseModel):
"""Configuration for the event log system"""
# Batch processing settings
batch_size: int = 100
batch_timeout_ms: int = 100
debounce_ms: int = 10
max_age_ms: int = 100
def get_db_path(self) -> Path:
"""Get the full path for a specific event log type"""
return EXO_GLOBAL_EVENT_DB

View File

@@ -8,13 +8,13 @@ from pathlib import Path
from typing import Any, cast
from loguru import logger
from pydantic import TypeAdapter
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, create_async_engine
from exo.shared.types.events import Event, EventParser, NodeId
from exo.shared.types.events._events import Heartbeat
from exo.shared.types.events.components import EventFromEventLog
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, IndexedEvent, event_tag
from .types import StoredEvent
@@ -73,9 +73,7 @@ class AsyncSQLiteEventStorage:
for event in events:
await self._write_queue.put((event, origin))
async def get_events_since(
self, last_idx: int, ignore_no_op_events: bool = False
) -> Sequence[EventFromEventLog[Event]]:
async def get_events_since(self, last_idx: int) -> Sequence[IndexedEvent]:
"""Retrieve events after a specific index."""
if self._closed:
raise RuntimeError("Storage is closed")
@@ -92,10 +90,10 @@ class AsyncSQLiteEventStorage:
)
rows = result.fetchall()
events: list[EventFromEventLog[Event]] = []
events: list[IndexedEvent] = []
for row in rows:
rowid: int = cast(int, row[0])
origin: str = cast(str, row[1])
# origin: str = cast(str, row[1])
# Parse JSON string to dict
raw_event_data = row[2] # type: ignore[reportAny] - SQLAlchemy result is Any
if isinstance(raw_event_data, str):
@@ -104,14 +102,12 @@ class AsyncSQLiteEventStorage:
)
else:
event_data = cast(dict[str, Any], raw_event_data)
event = EventParser.validate_python(event_data)
if ignore_no_op_events and event.__no_apply__:
continue
event: Event = TypeAdapter(Event).validate_python(event_data) # type: ignore
events.append(
EventFromEventLog(
event=event,
origin=NodeId(origin),
idx_in_log=rowid, # rowid becomes idx_in_log
IndexedEvent(
event=event, # type: ignore
# origin=NodeId(origin),
idx=rowid, # rowid becomes idx_in_log
)
)
@@ -325,7 +321,7 @@ class AsyncSQLiteEventStorage:
for event, origin in batch:
stored_event = StoredEvent(
origin=origin,
event_type=event.event_type,
event_type=event_tag(event),
event_id=str(event.event_id),
event_data=event.model_dump(
mode="json"
@@ -334,8 +330,7 @@ class AsyncSQLiteEventStorage:
session.add(stored_event)
await session.commit()
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
logger.debug(f"Committed batch of {len(batch)} events")
logger.debug(f"Committed batch of {len(batch)} events")
except OperationalError as e:
if "database is locked" in str(e):
@@ -393,7 +388,7 @@ class AsyncSQLiteEventStorage:
for event, origin in batch:
stored_event = StoredEvent(
origin=origin,
event_type=event.event_type,
event_type=event_tag(event),
event_id=str(event.event_id),
event_data=event.model_dump(mode="json"),
)
@@ -401,10 +396,9 @@ class AsyncSQLiteEventStorage:
await session.commit()
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
logger.debug(
f"Committed batch of {len(batch)} events after {retry_count} retries"
)
logger.debug(
f"Committed batch of {len(batch)} events after {retry_count} retries"
)
return
except OperationalError as e:

View File

@@ -0,0 +1,110 @@
import asyncio
from typing import cast
from loguru import logger
from sqlalchemy.exc import OperationalError
from exo.shared.constants import EXO_HOME
from exo.shared.db.config import EventLogConfig
from exo.shared.db.connector import AsyncSQLiteEventStorage
from exo.utils.fs import ensure_directory_exists
class EventLogManager:
"""
Manages both worker and global event log connectors.
Used by both master and worker processes with different access patterns:
- Worker: writes to worker_events, tails global_events
- Master (elected): writes to global_events, tails global_events
- Master (replica): writes to worker_events, tails global_events
"""
def __init__(self, config: EventLogConfig):
self._config = config
self._connector: AsyncSQLiteEventStorage | None = None
# Ensure base directory exists
ensure_directory_exists(EXO_HOME)
# TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
async def initialize(self, max_retries: int = 3) -> None:
"""Initialize both connectors with retry logic - call this during startup"""
# Both master and worker need both connectors
retry_count: int = 0
last_error: Exception | None = None
while retry_count < max_retries:
try:
await self.get_connector()
break
except OperationalError as e:
last_error = e
if "database is locked" in str(e) and retry_count < max_retries - 1:
retry_count += 1
delay = cast(float, 0.5 * (2**retry_count))
logger.warning(
f"Database locked while initializing db, retry {retry_count}/{max_retries} after {delay}s"
)
await asyncio.sleep(delay)
else:
logger.opt(exception=e).error(
f"Failed to initialize db after {retry_count + 1} attempts"
)
raise RuntimeError(
f"Could not initialize db after {retry_count + 1} attempts"
) from e
except Exception as e:
logger.opt(exception=e).error("Unexpected error initializing db")
raise
if retry_count >= max_retries and last_error:
raise RuntimeError(
f"Could not initialize db after {max_retries} attempts"
) from last_error
logger.bind(user_facing=True).info("Initialized all event log connectors")
async def get_connector(self) -> AsyncSQLiteEventStorage:
"""Get or create a connector for the specified log type"""
if not self._connector:
db_path = self._config.get_db_path()
try:
connector = AsyncSQLiteEventStorage(
db_path=db_path,
batch_size=self._config.batch_size,
batch_timeout_ms=self._config.batch_timeout_ms,
debounce_ms=self._config.debounce_ms,
max_age_ms=self._config.max_age_ms,
)
# Start the connector (creates tables if needed)
await connector.start()
self._connector = connector
logger.bind(user_facing=True).info(
f"Initialized db connector at {db_path}"
)
except Exception as e:
logger.bind(user_facing=True).opt(exception=e).error(
"Failed to create db connector"
)
raise
return self._connector
@property
def events(self) -> AsyncSQLiteEventStorage:
"""Access event log (must call initialize() first)"""
if not self._connector:
raise RuntimeError(
"Event log manager not initialized. Call initialize() first."
)
return self._connector
async def close(self) -> None:
"""Close all open connectors"""
assert self._connector is not None
await self._connector.close()
logger.bind(user_facing=True).info("Closed db connector")
self._connector = None

View File

@@ -1,15 +0,0 @@
"""SQLite event storage implementation."""
from .config import EventLogConfig, EventLogType
from .connector import AsyncSQLiteEventStorage
from .event_log_manager import EventLogManager
from .types import EventStorageProtocol, StoredEvent
__all__ = [
"AsyncSQLiteEventStorage",
"EventLogConfig",
"EventLogManager",
"EventLogType",
"EventStorageProtocol",
"StoredEvent",
]

View File

@@ -1,32 +0,0 @@
from enum import Enum
from pathlib import Path
from pydantic import BaseModel
from exo.shared.constants import EXO_GLOBAL_EVENT_DB, EXO_WORKER_EVENT_DB
class EventLogType(str, Enum):
"""Types of event logs in the system"""
WORKER_EVENTS = "worker_events"
GLOBAL_EVENTS = "global_events"
class EventLogConfig(BaseModel):
"""Configuration for the event log system"""
# Batch processing settings
batch_size: int = 100
batch_timeout_ms: int = 100
debounce_ms: int = 10
max_age_ms: int = 100
def get_db_path(self, log_type: EventLogType) -> Path:
"""Get the full path for a specific event log type"""
if log_type == EventLogType.WORKER_EVENTS:
return EXO_WORKER_EVENT_DB
elif log_type == EventLogType.GLOBAL_EVENTS:
return EXO_GLOBAL_EVENT_DB
else:
raise ValueError(f"Unknown log type: {log_type}")

View File

@@ -1,122 +0,0 @@
import asyncio
from typing import Dict, Optional, cast
from loguru import logger
from sqlalchemy.exc import OperationalError
from exo.shared.constants import EXO_HOME
from exo.shared.db.sqlite.config import EventLogConfig, EventLogType
from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
from exo.shared.utils.fs import ensure_directory_exists
class EventLogManager:
"""
Manages both worker and global event log connectors.
Used by both master and worker processes with different access patterns:
- Worker: writes to worker_events, tails global_events
- Master (elected): writes to global_events, tails global_events
- Master (replica): writes to worker_events, tails global_events
"""
def __init__(self, config: EventLogConfig):
self._config = config
self._connectors: Dict[EventLogType, AsyncSQLiteEventStorage] = {}
# Ensure base directory exists
ensure_directory_exists(EXO_HOME)
# TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
async def initialize(self, max_retries: int = 3) -> None:
"""Initialize both connectors with retry logic - call this during startup"""
# Both master and worker need both connectors
for log_type in [EventLogType.WORKER_EVENTS, EventLogType.GLOBAL_EVENTS]:
retry_count: int = 0
last_error: Optional[Exception] = None
while retry_count < max_retries:
try:
await self.get_connector(log_type)
break
except OperationalError as e:
last_error = e
if "database is locked" in str(e) and retry_count < max_retries - 1:
retry_count += 1
delay = cast(float, 0.5 * (2**retry_count))
logger.warning(
f"Database locked while initializing {log_type.value}, retry {retry_count}/{max_retries} after {delay}s"
)
await asyncio.sleep(delay)
else:
logger.opt(exception=e).error(
f"Failed to initialize {log_type.value} after {retry_count + 1} attempts"
)
raise RuntimeError(
f"Could not initialize {log_type.value} database after {retry_count + 1} attempts"
) from e
except Exception as e:
logger.opt(exception=e).error(
f"Unexpected error initializing {log_type.value}"
)
raise
if retry_count >= max_retries and last_error:
raise RuntimeError(
f"Could not initialize {log_type.value} database after {max_retries} attempts"
) from last_error
logger.bind(user_facing=True).info("Initialized all event log connectors")
async def get_connector(self, log_type: EventLogType) -> AsyncSQLiteEventStorage:
"""Get or create a connector for the specified log type"""
if log_type not in self._connectors:
db_path = self._config.get_db_path(log_type)
try:
connector = AsyncSQLiteEventStorage(
db_path=db_path,
batch_size=self._config.batch_size,
batch_timeout_ms=self._config.batch_timeout_ms,
debounce_ms=self._config.debounce_ms,
max_age_ms=self._config.max_age_ms,
)
# Start the connector (creates tables if needed)
await connector.start()
self._connectors[log_type] = connector
logger.bind(user_facing=True).info(
f"Initialized {log_type.value} connector at {db_path}"
)
except Exception as e:
logger.bind(user_facing=True).opt(exception=e).error(
f"Failed to create {log_type.value} connector"
)
raise
return self._connectors[log_type]
@property
def worker_events(self) -> AsyncSQLiteEventStorage:
"""Access worker events log (must call initialize() first)"""
if EventLogType.WORKER_EVENTS not in self._connectors:
raise RuntimeError(
"Event log manager not initialized. Call initialize() first."
)
return self._connectors[EventLogType.WORKER_EVENTS]
@property
def global_events(self) -> AsyncSQLiteEventStorage:
"""Access global events log (must call initialize() first)"""
if EventLogType.GLOBAL_EVENTS not in self._connectors:
raise RuntimeError(
"Event log manager not initialized. Call initialize() first."
)
return self._connectors[EventLogType.GLOBAL_EVENTS]
async def close_all(self) -> None:
"""Close all open connectors"""
for log_type, connector in self._connectors.items():
await connector.close()
logger.bind(user_facing=True).info(f"Closed {log_type.value} connector")
self._connectors.clear()

View File

@@ -1,13 +1,9 @@
from datetime import datetime, timezone
from typing import Any, Protocol, Sequence
from typing import Any
from sqlalchemy import DateTime, Index
from sqlmodel import JSON, Column, Field, SQLModel
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event
from exo.shared.types.events.components import EventFromEventLog
class StoredEvent(SQLModel, table=True):
"""SQLite representation of an event in the event log.
@@ -29,28 +25,3 @@ class StoredEvent(SQLModel, table=True):
)
__table_args__ = (Index("idx_events_origin_created", "origin", "created_at"),)
class EventStorageProtocol(Protocol):
"""Protocol for event storage implementations."""
async def append_events(self, events: Sequence[Event], origin: NodeId) -> None:
"""Append events to the log (fire-and-forget).
Events are queued for batched writing and assigned idx_in_log
when committed to storage.
"""
...
async def get_events_since(
self, last_idx: int
) -> Sequence[EventFromEventLog[Event]]:
"""Retrieve events after a specific index.
Returns events in idx_in_log order.
"""
...
async def close(self) -> None:
"""Close the storage connection and cleanup resources."""
...

183
src/exo/shared/election.py Normal file
View File

@@ -0,0 +1,183 @@
from typing import Self
import anyio
from anyio import (
CancelScope,
Event,
create_task_group,
get_cancelled_exc_class,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, Sender
from exo.utils.pydantic_ext import CamelCaseModel
ELECTION_TIMEOUT = 3.0
class ElectionMessage(CamelCaseModel):
clock: int
seniority: int
node_id: NodeId
# Could eventually include a list of neighbour nodes for centrality
def __lt__(self, other: Self):
if self.seniority != other.seniority:
return self.seniority < other.seniority
else:
return self.node_id < other.node_id
class ElectionResult(CamelCaseModel):
node_id: NodeId
is_new_master: bool
historic_messages: list[ConnectionMessage]
class Election:
def __init__(
self,
node_id: NodeId,
election_message_receiver: Receiver[ElectionMessage],
election_message_sender: Sender[ElectionMessage],
election_result_sender: Sender[ElectionResult],
connection_message_receiver: Receiver[ConnectionMessage],
*,
is_candidate: bool = True,
seniority: int = 0,
):
# If we aren't a candidate, simply don't increment seniority.
# For reference: This node can be elected master if all nodes are not master candidates
# Any master candidate will automatically win out over this node.
self.seniority = seniority if is_candidate else -1
self.clock = 0
self.node_id = node_id
# Every node spawns as master
self.master_node_id: NodeId = node_id
self._em_sender = election_message_sender
self._em_receiver = election_message_receiver
self._er_sender = election_result_sender
self._cm_receiver = connection_message_receiver
# Campaign state
self._candidates: list[ElectionMessage] = []
self._campaign_cancel_scope: CancelScope | None = None
self._campaign_done: Event | None = None
self._tg: TaskGroup | None = None
self._connection_messages: list[ConnectionMessage] = []
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
await self._campaign(None)
if self._campaign_cancel_scope is not None:
self._campaign_cancel_scope.cancel()
# Only exit once the latest campaign has finished
if self._campaign_done is not None:
await self._campaign_done.wait()
async def elect(self, node_id: NodeId) -> None:
is_new_master = node_id != self.master_node_id
self.master_node_id = node_id
await self._er_sender.send(
ElectionResult(
node_id=node_id,
is_new_master=is_new_master,
historic_messages=self._connection_messages,
)
)
async def shutdown(self) -> None:
if not self._tg:
logger.warning(
"Attempted to shutdown election service that was not running"
)
return
self._tg.cancel_scope.cancel()
async def _election_receiver(self) -> None:
with self._em_receiver as election_messages:
async for message in election_messages:
if message.node_id == self.node_id:
# Drop messages from us (See exo.routing.router)
continue
# If a new round is starting, we participate
if message.clock > self.clock:
self.clock = message.clock
await self._campaign(message)
continue
# Dismiss old messages
if message.clock < self.clock:
continue
logger.debug(f"Election added candidate {message}")
# Now we are processing this rounds messages - including the message that triggered this round.
self._candidates.append(message)
async def _connection_receiver(self) -> None:
with self._cm_receiver as connection_messages:
async for msg in connection_messages:
# These messages are strictly peer to peer
self.clock += 1
await self._campaign(None)
self._connection_messages.append(msg)
async def _campaign(self, initial_message: ElectionMessage | None) -> None:
# Kill the old campaign
if self._campaign_cancel_scope:
self._campaign_cancel_scope.cancel()
if self._campaign_done:
await self._campaign_done.wait()
candidates: list[ElectionMessage] = []
if initial_message:
candidates.append(initial_message)
self._candidates = candidates
done = Event()
self._campaign_done = done
assert self._tg is not None, (
"Election campaign started before election service initialized"
)
# Spin off a new campaign
self._tg.start_soon(self._complete_campaign, self.clock, candidates, done)
async def _complete_campaign(
self, clock: int, candidates: list[ElectionMessage], done: Event
) -> None:
scope = CancelScope()
try:
with scope:
self._campaign_cancel_scope = scope
logger.info(f"Election {clock} started")
candidates.append(self._election_status(clock))
await self._em_sender.send(self._election_status(clock))
await anyio.sleep(ELECTION_TIMEOUT)
# Election finished!
candidates = sorted(candidates)
logger.debug(f"Election queue {candidates}")
elected = candidates[-1]
logger.info("Election finished")
if self.node_id == elected.node_id and self.seniority >= 0:
self.seniority = max(self.seniority, len(candidates))
await self.elect(elected.node_id)
except get_cancelled_exc_class():
logger.info("Election cancelled")
finally:
if self._campaign_cancel_scope is scope:
self._campaign_cancel_scope = None
done.set()
def _election_status(self, clock: int | None = None) -> ElectionMessage:
c = self.clock if clock is None else clock
return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id)

View File

@@ -1,28 +0,0 @@
import logging
import os
from typing import TypeVar
from pydantic import BaseModel, ConfigDict, ValidationError
env_model_config = ConfigDict(
strict=True,
frozen=True,
extra="forbid",
)
class BaseEnv(BaseModel):
model_config = env_model_config
EnvSchema = TypeVar("EnvSchema", bound=BaseEnv)
def get_validated_env(
environment_schema: type[EnvSchema], logger: logging.Logger
) -> EnvSchema:
try:
return environment_schema.model_validate(os.environ, strict=True)
except ValidationError as e:
logger.error("Environment Variables Validation Failed: %s", e)
raise e

View File

@@ -18,6 +18,7 @@ class AsyncConnection[SendT, RecvT]:
- await send(...) from asyncio code
- send_sync(...) from executor/background threads
"""
def __init__(self, conn: Connection):
self._conn = conn
self._send_lock = threading.Lock()
@@ -44,7 +45,7 @@ class AsyncConnection[SendT, RecvT]:
def _recv_blocking(self) -> RecvT:
# Not strictly needed in your parent, but safe if misused elsewhere
with self._recv_lock:
return self._conn.recv() # type: ignore[no-any-return]
return self._conn.recv() # type: ignore[no-any-return]
async def poll(self, timeout: float | None = None) -> bool:
return await asyncio.to_thread(self._conn.poll, timeout)
@@ -52,12 +53,15 @@ class AsyncConnection[SendT, RecvT]:
def close(self) -> None:
self._conn.close()
_conn: Optional[AsyncConnection[RunnerResponse, RunnerMessage]] = None
def set_conn(c: AsyncConnection[RunnerResponse, RunnerMessage]) -> None:
global _conn
_conn = c
def get_conn() -> AsyncConnection[RunnerResponse, RunnerMessage]:
if _conn is None:
raise RuntimeError("Global conn has not been set yet")

View File

@@ -12,7 +12,7 @@ import time
from enum import Enum
from typing import Optional
from exo.shared.utils.fs import StrPath, ensure_parent_directory_exists
from exo.utils.fs import StrPath, ensure_parent_directory_exists
# open in read-write mode, creates file if it doesn't exist already,
# closes this file descriptor in any children processes (prevents FD leaking),

View File

@@ -33,7 +33,7 @@ from typing import Callable
from cobs import cobs # pyright: ignore[reportMissingTypeStubs]
from pytest import LogCaptureFixture
from exo.shared.utils.fs import (
from exo.utils.fs import (
StrPath,
delete_if_exists,
ensure_parent_directory_exists,

View File

@@ -1,249 +0,0 @@
from __future__ import annotations
import hashlib
import logging
import os
from pathlib import Path
from typing import final
import base58
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from filelock import FileLock
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@final
class PeerId:
"""
A libp2p peer identifier derived from a cryptographic public key.
Compatible with py-libp2p's PeerID interface.
"""
def __init__(self, peer_id_bytes: bytes) -> None:
self._bytes = peer_id_bytes
@staticmethod
def from_bytes(data: bytes) -> "PeerId":
"""Create PeerId from raw bytes."""
return PeerId(data)
@staticmethod
def from_public_key(public_key_bytes: bytes) -> "PeerId":
"""Create PeerId from a public key by hashing it."""
# For Ed25519 keys, libp2p uses the identity hash (no hashing) for keys <= 42 bytes
# Since Ed25519 public keys are 32 bytes, we use identity hash
if len(public_key_bytes) <= 42:
return PeerId(public_key_bytes)
else:
# For larger keys, use SHA-256
hash_digest = hashlib.sha256(public_key_bytes).digest()
return PeerId(hash_digest)
def to_bytes(self) -> bytes:
"""Return the raw bytes of this PeerId."""
return self._bytes
def to_base58(self) -> str:
"""Return the base58-encoded string representation."""
return base58.b58encode(self._bytes).decode("ascii")
def __str__(self) -> str:
"""Return the base58-encoded string representation."""
return self.to_base58()
def __repr__(self) -> str:
"""Return debug representation."""
return f"PeerId('{self.to_base58()}')"
def __eq__(self, other: object) -> bool:
"""Check equality with another PeerId."""
if not isinstance(other, PeerId):
return False
return self._bytes == other._bytes
def __hash__(self) -> int:
"""Make PeerId hashable."""
return hash(self._bytes)
@final
class Keypair:
"""
A py-libp2p compatible keypair implementation.
Provides the same interface as py-libp2p's KeyPair.
"""
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
self._private_key = private_key
self._public_key = private_key.public_key()
@staticmethod
def generate_ed25519() -> "Keypair":
"""Generate a new Ed25519 keypair."""
private_key = ed25519.Ed25519PrivateKey.generate()
return Keypair(private_key)
@staticmethod
def from_protobuf_encoding(data: bytes) -> "Keypair":
"""
Deserialize a keypair from libp2p protobuf encoding.
Compatible with py-libp2p's serialization format.
"""
if len(data) < 2:
raise ValueError("Invalid protobuf data: too short")
# Simple protobuf parsing for our specific use case
# We expect: field 1 (type) as varint, field 2 (data) as bytes
offset = 0
# Parse type field (field tag 1, wire type 0 = varint)
if data[offset] != 0x08: # field 1, varint
raise ValueError("Expected type field")
offset += 1
key_type = data[offset]
offset += 1
if key_type != 1: # Ed25519
raise ValueError(f"Unsupported key type: {key_type}")
# Parse data field (field tag 2, wire type 2 = length-delimited)
if offset >= len(data) or data[offset] != 0x12: # field 2, bytes
raise ValueError("Expected data field")
offset += 1
# Parse length
data_length = data[offset]
offset += 1
if data_length not in (32, 64):
raise ValueError(f"Invalid Ed25519 private key length: {data_length}")
if offset + data_length > len(data):
raise ValueError("Truncated private key data")
key_data = data[offset : offset + data_length]
try:
if data_length == 64:
# libp2p format: 32 bytes private key seed + 32 bytes public key
private_key_seed = key_data[:32]
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(
private_key_seed
)
else:
# Raw 32-byte private key
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(key_data)
return Keypair(private_key)
except Exception as e:
raise ValueError(f"Invalid Ed25519 private key: {e}") from e
def to_protobuf_encoding(self) -> bytes:
"""
Serialize this keypair to libp2p protobuf encoding.
Compatible with py-libp2p's serialization format.
"""
private_key_bytes = self._private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
public_key_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
# libp2p Ed25519 format: private key seed (32) + public key (32)
combined_key_data = private_key_bytes + public_key_bytes
# Build protobuf manually for our simple case
# Field 1 (type): tag=0x08, value=1 (Ed25519)
# Field 2 (data): tag=0x12, length=64, data=combined_key_data
result = bytearray()
result.extend([0x08, 0x01]) # field 1: type = 1 (Ed25519)
result.extend([0x12, 0x40]) # field 2: length = 64 bytes
result.extend(combined_key_data)
return bytes(result)
def to_peer_id(self) -> PeerId:
"""Generate a PeerId from this keypair's public key."""
public_key_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
return PeerId.from_public_key(public_key_bytes)
def sign(self, data: bytes) -> bytes:
"""Sign data with this keypair's private key."""
return self._private_key.sign(data)
def verify(self, data: bytes, signature: bytes) -> bool:
"""Verify a signature against data using this keypair's public key."""
try:
self._public_key.verify(signature, data)
return True
except Exception:
return False
@property
def public_key_bytes(self) -> bytes:
"""Get the raw public key bytes."""
return self._public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
@property
def private_key_bytes(self) -> bytes:
"""Get the raw private key bytes."""
return self._private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
# py-libp2p compatibility properties
@property
def private_key(self) -> ed25519.Ed25519PrivateKey:
"""Access to the underlying private key for py-libp2p compatibility."""
return self._private_key
@property
def public_key(self) -> ed25519.Ed25519PublicKey:
"""Access to the underlying public key for py-libp2p compatibility."""
return self._public_key
def get_node_id_keypair(
path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR,
) -> Keypair:
"""
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
def lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
# operate with cross-process lock to avoid race conditions
with FileLock(lock_path(path)):
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
# if non-zero EOF, then file exists => use to get node-ID
if f.tell() != 0:
f.seek(0) # go to start & read protobuf-encoded bytes
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logging.warning(
f"Encountered error when trying to get keypair: {e}"
)
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -1,32 +1,13 @@
from __future__ import annotations
import sys
from logging import Logger
from pathlib import Path
import loguru
from loguru import logger
from exo.shared.constants import EXO_TEST_LOG
def is_user_facing(record: loguru.Record) -> bool:
return ("user_facing" in record["extra"]) and record["extra"]["user_facing"]
def logger_setup(log_file: Path, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logger.remove()
if verbosity == 0:
_ = logger.add( # type: ignore
sys.__stderr__, # type: ignore
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
level="INFO",
colorize=True,
enqueue=True,
filter=is_user_facing,
)
elif verbosity == 1:
_ = logger.add( # type: ignore
sys.__stderr__, # type: ignore
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
@@ -40,11 +21,12 @@ def logger_setup(log_file: Path, verbosity: int = 0):
format="[ {time:HH:mm:ss.SSS} | <level>{level: <8}</level> | {name}:{function}:{line} ] <level>{message}</level>",
level="DEBUG",
colorize=True,
enqueue=True,
)
_ = logger.add(
log_file,
format="[ {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} ] {message}",
level="DEBUG",
level="INFO",
enqueue=True,
)
@@ -52,10 +34,3 @@ def logger_setup(log_file: Path, verbosity: int = 0):
def logger_cleanup():
"""Flush all queues before shutting down so any in-flight logs are written to disk"""
logger.complete()
def logger_test_install(py_logger: Logger):
"""Installs a default python logger into the Loguru environment by capturing all its handlers - intended to be used for pytest compatibility, not within the main codebase"""
logger_setup(EXO_TEST_LOG, 3)
for handler in py_logger.handlers:
logger.add(handler)

View File

@@ -1,11 +1,11 @@
from typing import List
from pydantic import BaseModel
from exo.shared.types.models import ModelMetadata
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelCard(BaseModel):
class ModelCard(CamelCaseModel):
short_id: str
model_id: str
name: str
@@ -23,9 +23,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-V3-0324-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
pretty_name="DeepSeek V3 0324 (4-bit)",
storage_size_kilobytes=409706307,
storage_size=Memory.from_kb(409706307),
n_layers=61,
),
),
@@ -36,9 +36,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-v3-0324-8bit",
model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
pretty_name="DeepSeek V3 0324 (8-bit)",
storage_size_kilobytes=754706307,
storage_size=Memory.from_kb(754706307),
n_layers=61,
),
),
@@ -49,9 +49,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-V3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size_kilobytes=754706307,
storage_size=Memory.from_kb(754706307),
n_layers=61,
),
),
@@ -62,9 +62,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-V3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size_kilobytes=754706307 // 2, # TODO !!!!!
storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
n_layers=61,
),
),
@@ -76,9 +76,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-R1-0528-4bit",
model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
pretty_name="DeepSeek R1 671B (4-bit)",
storage_size_kilobytes=409706307,
storage_size=Memory.from_kb(409706307),
n_layers=61,
),
),
@@ -89,9 +89,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-R1-0528-8bit",
model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
pretty_name="DeepSeek R1 671B (8-bit)",
storage_size_kilobytes=754998771712 // 1024,
storage_size=Memory.from_bytes(754998771712),
n_layers=61,
),
),
@@ -103,9 +103,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B",
storage_size_kilobytes=4411528,
storage_size=Memory.from_kb(4411528),
n_layers=32,
),
),
@@ -116,9 +116,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B",
storage_size_kilobytes=38758160,
storage_size=Memory.from_kb(38758160),
n_layers=80,
),
),
@@ -130,9 +130,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B",
storage_size_kilobytes=678948,
storage_size=Memory.from_kb(678948),
n_layers=16,
),
),
@@ -143,9 +143,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B",
storage_size_kilobytes=1765062,
storage_size=Memory.from_kb(1765062),
n_layers=28,
),
),
@@ -157,9 +157,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size_kilobytes=38758160,
storage_size=Memory.from_kb(38758160),
n_layers=80,
),
),
@@ -171,9 +171,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k",
storage_size_kilobytes=2099262,
storage_size=Memory.from_kb(2099262),
n_layers=32,
),
),
@@ -184,9 +184,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k",
storage_size_kilobytes=2099262,
storage_size=Memory.from_kb(2099262),
n_layers=32,
),
),
@@ -198,9 +198,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Qwen3-0.6B-4bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B",
storage_size_kilobytes=327512,
storage_size=Memory.from_kb(327512),
n_layers=28,
),
),
@@ -211,9 +211,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Qwen3-30B-A3B-4bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B (Active 3B)",
storage_size_kilobytes=16772092,
storage_size=Memory.from_kb(16772092),
n_layers=48,
),
),
@@ -225,9 +225,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/granite-3.3-2b-instruct-fp16",
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
pretty_name="Granite 3.3 2B",
storage_size_kilobytes=4948320,
storage_size=Memory.from_kb(4948320),
n_layers=40,
),
),
@@ -238,9 +238,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
pretty_name="Granite 3.3 8B",
storage_size_kilobytes=15958720,
storage_size=Memory.from_kb(15958720),
n_layers=40,
),
),
@@ -252,9 +252,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/SmolLM-135M-4bit",
model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
pretty_name="Smol LM 135M",
storage_size_kilobytes=73940,
storage_size=Memory.from_kb(73940),
n_layers=30,
),
),

View File

@@ -6,7 +6,8 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.types.models import ModelMetadata
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -65,7 +66,7 @@ async def get_config_data(model_id: str) -> ConfigData:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> int:
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
@@ -83,12 +84,12 @@ async def get_safetensors_size(model_id: str) -> int:
metadata = index_data.metadata
if metadata is not None:
return metadata.total_size
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return info.safetensors.total
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: Dict[str, ModelMetadata] = {}
@@ -109,8 +110,8 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
mem_size_bytes = await get_safetensors_size(model_id)
return ModelMetadata(
model_id=model_id,
model_id=ModelId(model_id),
pretty_name=model_id,
storage_size_kilobytes=mem_size_bytes // 1024,
storage_size=mem_size_bytes,
n_layers=num_layers,
)

View File

@@ -0,0 +1,313 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.common import NodeId
from exo.utils.channels import channel
# ======= #
# Helpers #
# ======= #
def em(clock: int, seniority: int, node_id: str) -> ElectionMessage:
return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id))
@pytest.fixture
def fast_timeout(monkeypatch: pytest.MonkeyPatch):
# Keep campaigns fast; user explicitly allows tests to shorten the timeout.
import exo.shared.election as election_mod
monkeypatch.setattr(election_mod, "ELECTION_TIMEOUT", 0.05, raising=True)
yield
# ======================================= #
# TESTS #
# ======================================= #
@pytest.mark.anyio
async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
fast_timeout: None,
) -> None:
"""
Start a round by injecting an ElectionMessage with higher clock.
With only our node effectively 'winning', we should broadcast once and update seniority.
"""
# Outbound election messages from the Election (we'll observe these)
em_out_tx, em_out_rx = channel[ElectionMessage]()
# Inbound election messages to the Election (we'll inject these)
em_in_tx, em_in_rx = channel[ElectionMessage]()
# Election results produced by the Election (we'll observe these)
er_tx, er_rx = channel[ElectionResult]()
# Connection messages (unused in this test but required by ctor)
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("B"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Trigger new round at clock=1 (peer announces it)
await em_in_tx.send(em(clock=1, seniority=0, node_id="A"))
# Expect our broadcast back to the peer side for this round only
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("B"):
break
# Wait for the round to finish and produce an ElectionResult
result = await er_rx.receive()
assert result.node_id == NodeId("B")
# We spawned as master; electing ourselves again is not "new master".
assert result.is_new_master is False
# Close inbound streams to end the receivers (and run())
await em_in_tx.aclose()
await cm_tx.aclose()
# We should have updated seniority to 2 (A + B).
assert election.seniority == 2
@pytest.mark.anyio
async def test_peer_with_higher_seniority_wins_and_we_switch_master(
fast_timeout: None,
) -> None:
"""
If a peer with clearly higher seniority participates in the round, they should win.
We should broadcast our status exactly once for this round, then switch master.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("ME"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Start round with peer's message (higher seniority)
await em_in_tx.send(em(clock=1, seniority=10, node_id="PEER"))
# We should still broadcast our status exactly once for this round
while True:
got = await em_out_rx.receive()
if got.clock == 1:
assert got.seniority == 0
break
# After the timeout, election result should report the peer as master
result = await er_rx.receive()
assert result.node_id == NodeId("PEER")
assert result.is_new_master is True
await em_in_tx.aclose()
await cm_tx.aclose()
# We lost → seniority unchanged
assert election.seniority == 0
@pytest.mark.anyio
async def test_ignores_older_messages(fast_timeout: None) -> None:
"""
Messages with a lower clock than the current round are ignored by the receiver.
Expect exactly one broadcast for the higher clock round.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("ME"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Newer round arrives first -> triggers campaign at clock=2
await em_in_tx.send(em(clock=2, seniority=0, node_id="A"))
while True:
first = await em_out_rx.receive()
if first.clock == 2:
break
# Older message (clock=1) must be ignored (no second broadcast)
await em_in_tx.send(em(clock=1, seniority=999, node_id="B"))
got_second = False
with move_on_after(0.2):
_ = await em_out_rx.receive()
got_second = True
assert not got_second, "Should not receive a broadcast for an older round"
await em_in_tx.aclose()
await cm_tx.aclose()
# Not asserting on the result; focus is on ignore behavior.
@pytest.mark.anyio
async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
fast_timeout: None,
) -> None:
"""
Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("ME"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Round 1 at clock=1
await em_in_tx.send(em(clock=1, seniority=0, node_id="X"))
while True:
m1 = await em_out_rx.receive()
if m1.clock == 1:
break
# Round 2 at clock=2
await em_in_tx.send(em(clock=2, seniority=0, node_id="Y"))
while True:
m2 = await em_out_rx.receive()
if m2.clock == 2:
break
await em_in_tx.aclose()
await cm_tx.aclose()
# Not asserting on who won; just that both rounds were broadcast.
@pytest.mark.anyio
async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -> None:
"""
When we win against two peers in the same round, our seniority becomes
max(existing, number_of_candidates). With existing=0: expect 3 (us + A + B).
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("ME"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Start round at clock=7 with two peer participants
await em_in_tx.send(em(clock=7, seniority=0, node_id="A"))
await em_in_tx.send(em(clock=7, seniority=0, node_id="B"))
# We should see exactly one broadcast from us for this round
while True:
got = await em_out_rx.receive()
if got.clock == 7 and got.node_id == NodeId("ME"):
break
# Wait for the election to finish so seniority updates
_ = await er_rx.receive()
await em_in_tx.aclose()
await cm_tx.aclose()
# We + A + B = 3 → new seniority expected to be 3
assert election.seniority == 3
@pytest.mark.anyio
async def test_connection_message_triggers_new_round_broadcast(
fast_timeout: None,
) -> None:
"""
A connection message increments the clock and starts a new campaign.
We should observe a broadcast at the incremented clock.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
election = Election(
node_id=NodeId("ME"),
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
is_candidate=True,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Send any connection message object; we close quickly to cancel before result creation
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
)
)
# Expect a broadcast for the new round at clock=1
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("ME"):
break
# Close promptly to avoid waiting for campaign completion
await em_in_tx.aclose()
await cm_tx.aclose()
# After cancellation (before election finishes), no seniority changes asserted here.

View File

@@ -1,7 +1,7 @@
import pytest
from exo.shared.ipc.file_mutex.flock_mutex import FlockMutex, LockType
from exo.shared.utils.fs import delete_if_exists, make_temp_path
from exo.utils.fs import delete_if_exists, make_temp_path
def test_lock_held():

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import contextlib
import logging
import multiprocessing
@@ -13,8 +11,8 @@ from typing import Optional
from pytest import LogCaptureFixture
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.shared.keypair import get_node_id_keypair
NUM_CONCURRENT_PROCS = 10

View File

@@ -1,612 +0,0 @@
import asyncio
import json
import tempfile
from pathlib import Path
from typing import Any, Generator, cast
from uuid import uuid4
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from exo.shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import ChunkGenerated
from exo.shared.types.events.chunks import ChunkType, TokenChunk
# Type ignore comment for all protected member access in this test file
# pyright: reportPrivateUsage=false
def _load_json_data(raw_data: str) -> dict[str, Any]:
"""Helper function to load JSON data with proper typing."""
return cast(dict[str, Any], json.loads(raw_data))
@pytest.fixture
def temp_db_path() -> Generator[Path, None, None]:
"""Create a temporary database file for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
yield Path(f.name)
# Cleanup
Path(f.name).unlink(missing_ok=True)
@pytest.fixture
def sample_node_id() -> NodeId:
"""Create a sample NodeId for testing."""
return NodeId()
class TestAsyncSQLiteEventStorage:
"""Test suite for AsyncSQLiteEventStorage focused on storage functionality."""
@pytest.mark.asyncio
async def test_initialization_creates_tables(self, temp_db_path: Path) -> None:
"""Test that database initialization creates the events table."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Verify table exists by querying directly
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text(
"SELECT name FROM sqlite_master WHERE type='table' AND name='events'"
)
)
tables = result.fetchall()
assert len(tables) == 1
assert tables[0][0] == "events"
await storage.close()
@pytest.mark.asyncio
async def test_start_twice_raises_error(self, temp_db_path: Path) -> None:
"""Test that starting storage twice raises an error."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
with pytest.raises(RuntimeError, match="Storage already started"):
await storage.start()
await storage.close()
@pytest.mark.asyncio
async def test_direct_database_operations(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test direct database operations without event parsing."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Insert test data directly
test_data = {
"event_type": "test_event",
"test_field": "test_value",
"number": 42,
}
async with AsyncSession(storage._engine) as session:
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": sample_node_id,
"event_type": "test_event",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data),
},
)
await session.commit()
# Query data back
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0][0] == 1 # rowid
assert rows[0][1] == sample_node_id # origin
raw_json = cast(str, rows[0][2])
retrieved_data = _load_json_data(raw_json)
assert retrieved_data == test_data
await storage.close()
@pytest.mark.asyncio
async def test_rowid_auto_increment(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test that rowid auto-increments correctly."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Insert multiple records
test_records = [
{"event_type": "test_event_1", "data": "first"},
{"event_type": "test_event_2", "data": "second"},
{"event_type": "test_event_3", "data": "third"},
]
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
for record in test_records:
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": sample_node_id,
"event_type": record["event_type"],
"event_id": str(uuid4()),
"event_data": json.dumps(record),
},
)
await session.commit()
# Query back and verify rowid sequence
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text("SELECT rowid, event_data FROM events ORDER BY rowid")
)
rows = result.fetchall()
assert len(rows) == 3
for i, row in enumerate(rows):
assert row[0] == i + 1 # rowid starts at 1
raw_json = cast(str, row[1])
retrieved_data = _load_json_data(raw_json)
assert retrieved_data == test_records[i]
await storage.close()
@pytest.mark.asyncio
async def test_get_last_idx(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test that rowid returns correctly from db."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Insert multiple records
test_records = [
{"event_type": "test_event_1", "data": "first"},
{"event_type": "test_event_2", "data": "second"},
{"event_type": "test_event_3", "data": "third"},
]
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
for record in test_records:
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": sample_node_id,
"event_type": record["event_type"],
"event_id": str(uuid4()),
"event_data": json.dumps(record),
},
)
await session.commit()
last_idx = await storage.get_last_idx()
assert last_idx == 3
await storage.close()
@pytest.mark.asyncio
async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None:
"""Test rowid sequence across multiple origins."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
origin1 = NodeId()
origin2 = NodeId()
# Insert interleaved records from different origins
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
# Origin 1 - record 1
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": origin1,
"event_type": "event_1",
"event_id": str(uuid4()),
"event_data": json.dumps({"from": "origin1", "seq": 1}),
},
)
# Origin 2 - record 2
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": origin2,
"event_type": "event_2",
"event_id": str(uuid4()),
"event_data": json.dumps({"from": "origin2", "seq": 2}),
},
)
# Origin 1 - record 3
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": origin1,
"event_type": "event_3",
"event_id": str(uuid4()),
"event_data": json.dumps({"from": "origin1", "seq": 3}),
},
)
await session.commit()
# Verify sequential rowid regardless of origin
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
)
rows = result.fetchall()
assert len(rows) == 3
assert rows[0][0] == 1 # First rowid
assert rows[1][0] == 2 # Second rowid
assert rows[2][0] == 3 # Third rowid
# Verify data integrity
raw_json1 = cast(str, rows[0][2])
raw_json2 = cast(str, rows[1][2])
raw_json3 = cast(str, rows[2][2])
data1 = _load_json_data(raw_json1)
data2 = _load_json_data(raw_json2)
data3 = _load_json_data(raw_json3)
assert data1["from"] == "origin1" and data1["seq"] == 1
assert data2["from"] == "origin2" and data2["seq"] == 2
assert data3["from"] == "origin1" and data3["seq"] == 3
await storage.close()
@pytest.mark.asyncio
async def test_query_events_since_index(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test querying events after a specific rowid."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Insert 10 test records
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
for i in range(10):
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": sample_node_id,
"event_type": f"event_{i}",
"event_id": str(uuid4()),
"event_data": json.dumps({"index": i}),
},
)
await session.commit()
# Query events after index 5
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text(
"SELECT rowid, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid"
),
{"last_idx": 5},
)
rows = result.fetchall()
assert len(rows) == 5 # Should get records 6-10
for i, row in enumerate(rows):
assert row[0] == i + 6 # rowid 6, 7, 8, 9, 10
raw_json = cast(str, row[1])
data = _load_json_data(raw_json)
assert data["index"] == i + 5 # index 5, 6, 7, 8, 9
await storage.close()
@pytest.mark.asyncio
async def test_empty_query(self, temp_db_path: Path) -> None:
"""Test querying when no events exist."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text(
"SELECT rowid, origin, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid"
),
{"last_idx": 0},
)
rows = result.fetchall()
assert len(rows) == 0
await storage.close()
@pytest.mark.asyncio
async def test_operations_after_close_raise_error(self, temp_db_path: Path) -> None:
"""Test that operations after close work properly."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
await storage.close()
# These should not raise errors since we're not using the public API
assert storage._closed is True
assert storage._engine is not None # Engine should still exist but be disposed
@pytest.mark.asyncio
async def test_multiple_close_calls_safe(self, temp_db_path: Path) -> None:
"""Test that multiple close calls are safe."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
await storage.close()
await storage.close() # Should not raise an error
@pytest.mark.asyncio
async def test_json_data_types(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test that various JSON data types are handled correctly."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Test various JSON data types
test_data = {
"string": "test string",
"number": 42,
"float": 3.14,
"boolean": True,
"null": None,
"array": [1, 2, 3, "four"],
"object": {"nested": "value", "deep": {"deeper": "nested"}},
"unicode": "测试 🚀",
}
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": sample_node_id,
"event_type": "complex_event",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data),
},
)
await session.commit()
# Query back and verify data integrity
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text("SELECT event_data FROM events WHERE event_type = :event_type"),
{"event_type": "complex_event"},
)
rows = result.fetchall()
assert len(rows) == 1
raw_json = cast(str, rows[0][0])
retrieved_data = _load_json_data(raw_json)
assert retrieved_data == test_data
await storage.close()
@pytest.mark.asyncio
async def test_concurrent_inserts(self, temp_db_path: Path) -> None:
"""Test concurrent inserts maintain rowid ordering."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
async def insert_batch(origin_id: str, batch_id: int, count: int) -> None:
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
for i in range(count):
await session.execute(
text(
"INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"
),
{
"origin": origin_id,
"event_type": f"batch_{batch_id}_event_{i}",
"event_id": str(uuid4()),
"event_data": json.dumps({"batch": batch_id, "item": i}),
},
)
await session.commit()
# Run multiple concurrent insert batches
origin1 = str(uuid4())
origin2 = str(uuid4())
origin3 = str(uuid4())
await asyncio.gather(
insert_batch(origin1, 1, 5),
insert_batch(origin2, 2, 5),
insert_batch(origin3, 3, 5),
)
# Verify all records were inserted and rowid is sequential
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
result = await session.execute(
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid")
)
rows = result.fetchall()
assert len(rows) == 15 # 3 batches * 5 records each
# Verify rowid sequence is maintained
for i, row in enumerate(rows):
assert row[0] == i + 1 # rowid should be sequential
await storage.close()
@pytest.mark.asyncio
async def test_chunk_generated_event_serialization(
self, temp_db_path: Path, sample_node_id: NodeId
) -> None:
"""Test that ChunkGenerated event with nested types can be serialized and deserialized correctly."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(
db_path=temp_db_path,
batch_size=default_config.batch_size,
batch_timeout_ms=default_config.batch_timeout_ms,
debounce_ms=default_config.debounce_ms,
max_age_ms=default_config.max_age_ms,
)
await storage.start()
# Create a ChunkGenerated event with nested TokenChunk
command_id = CommandId()
token_chunk = TokenChunk(
text="Hello, world!",
token_id=42,
finish_reason="stop",
chunk_type=ChunkType.token,
command_id=command_id,
idx=0,
model="test-model",
)
chunk_generated_event = ChunkGenerated(command_id=command_id, chunk=token_chunk)
# Store the event using the storage API
await storage.append_events([chunk_generated_event], sample_node_id)
# Wait for batch to be written
await asyncio.sleep(0.5)
# Retrieve the event
events = await storage.get_events_since(0)
# Verify we got the event back
assert len(events) == 1
retrieved_event_wrapper = events[0]
assert retrieved_event_wrapper.origin == sample_node_id
# Verify the event was deserialized correctly
retrieved_event = retrieved_event_wrapper.event
assert isinstance(retrieved_event, ChunkGenerated)
assert retrieved_event.command_id == command_id
# Verify the nested chunk was deserialized correctly
retrieved_chunk = retrieved_event.chunk
assert isinstance(retrieved_chunk, TokenChunk)
assert retrieved_chunk.chunk_type == ChunkType.token
assert retrieved_chunk.command_id == command_id
assert retrieved_chunk.idx == 0
assert retrieved_chunk.model == "test-model"
# Verify the chunk data
assert retrieved_chunk.text == "Hello, world!"
assert retrieved_chunk.token_id == 42
assert retrieved_chunk.finish_reason == "stop"
await storage.close()

Some files were not shown because too many files have changed in this diff Show More