diff --git a/.flake-modules/go-forwarder.nix b/.flake-modules/go-forwarder.nix
deleted file mode 100644
index 34a38cf1..00000000
--- a/.flake-modules/go-forwarder.nix
+++ /dev/null
@@ -1,71 +0,0 @@
-# Configures the Golang support and builds the forwarder
-# TODO: split this up in the future as this is unrelated tasks??
-
-# Top-level parameters that are bound to the provider flake
-# These are passed from `flake.nix` using importApply
-{
-  localSelf,
-  flake-parts-lib,
-  nixpkgs-lib,
-  ...
-}:
-
-# These values would bind to the consumer flake when this flake module is imported:
-{
-  config,
-  self,
-  inputs,
-  getSystem,
-  moduleWithSystem,
-  withSystem,
-  ...
-}:
-
-# The actual flake-parts module configuration
-{
-  perSystem =
-    {
-      config,
-      self',
-      inputs',
-      pkgs,
-      system,
-      ...
-    }:
-    let
-#      flakeRoot = nixpkgs-lib.getExe config.flake-root.package;
-#
-#      # Build the networking/forwarder Go utility.
-#      forwarder = pkgs.buildGoModule {
-#        pname = "exo-forwarder";
-#        version = "0.1.0";
-#        src = "${flakeRoot}/networking/forwarder";
-#
-#        vendorHash = "sha256-BXIGg2QYqHDz2TNe8hLAGC6jVlffp9766H+WdkkuVgA=";
-#
-#        # Only the main package at the repository root needs building.
-#        subPackages = [ "." ];
-#      };
    in
-    {
-      packages = {
-#        inherit forwarder;
-      };
-
-      apps = {
-#        forwarder = {
-#          type = "app";
-#          program = "${forwarder}/bin/forwarder";
-#        };
-      };
-
-      make-shells.default = {
-        # Go 1.24 compiler – align with go.mod
-        packages = [ pkgs.go_1_24 ];
-        shellHook = ''
-          GOPATH="''$(${nixpkgs-lib.getExe config.flake-root.package})"/.go_cache
-          export GOPATH
-        '';
-      };
-    };
-}
diff --git a/.flake-modules/just-flake.nix b/.flake-modules/just-flake.nix
deleted file mode 100644
index 2208a58c..00000000
--- a/.flake-modules/just-flake.nix
+++ /dev/null
@@ -1,54 +0,0 @@
-# Provides pretty banner & command index for this flake
-
-# Top-level parameters that are bound to the provider flake
-# These are passed from `flake.nix` using importApply
-{
-  localSelf,
-  flake-parts-lib,
-  nixpkgs-lib,
-  just-flake,
-  ...
-}:
-
-# These values would bind to the consumer flake when this flake module is imported:
-{
-  config,
-  self,
-  inputs,
-  getSystem,
-  moduleWithSystem,
-  withSystem,
-  ...
-}:
-
-# The actual flake-parts module configuration
-{
-  imports = [ just-flake.flakeModule ];
-  perSystem =
-    {
-      config,
-      self',
-      inputs',
-      pkgs,
-      system,
-      ...
-    }:
-    {
-      just-flake.features = {
-        # treefmt.enable = true;
-        # rust.enable = true;
-        # convco.enable = true;
-        # hello = {
-        #   enable = true;
-        #   justfile = ''
-        #     hello:
-        #       echo Hello World
-        #   '';
-        # };
-      };
-
-      make-shells.default = {
-        inputsFrom = [ config.just-flake.outputs.devShell ];
-      };
-    };
-}
diff --git a/.flake-modules/macmon.nix b/.flake-modules/macmon.nix
deleted file mode 100644
index 5df0cdf4..00000000
--- a/.flake-modules/macmon.nix
+++ /dev/null
@@ -1,30 +0,0 @@
-# Provides macmon binary for the worker.
-
-# These values would bind to the consumer flake when this flake module is imported:
-{
-  config,
-  self,
-  inputs,
-  getSystem,
-  moduleWithSystem,
-  withSystem,
-  ...
-}:
-
-# The actual flake-parts module configuration
-{
-  perSystem =
-    {
-      config,
-      self',
-      inputs',
-      pkgs,
-      system,
-      ...
-    }:
-    {
-      make-shells.default = {
-        packages = if (system == "aarch64-darwin") then ([ pkgs.macmon ]) else ([]);
-      };
-    };
-}
diff --git a/.gitignore b/.gitignore
index 310df30d..19b4dd09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,8 +12,6 @@ hosts*.json
 # TODO figure out how to properly solve the issue with these target directories showing up
 networking/target/
 networking/topology/target/
-rust/target/
-rust/Cargo.lock
 
 build/
 dist/
@@ -26,4 +24,4 @@ dist/
 just-flake.just
 
 # for the gitingest enthusiasts
-digest.txt
\ No newline at end of file
+digest.txt
diff --git a/.idea/exo-v2.iml b/.idea/exo-v2.iml
index 5357eaa9..aa638174 100644
--- a/.idea/exo-v2.iml
+++ b/.idea/exo-v2.iml
@@ -10,11 +10,19 @@
+
+
+
+
+
+
+
+
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 94a25f7f..35eb1ddf 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file
diff --git a/copy_model.sh b/copy_model.sh
new file mode 100755
index 00000000..f5c985aa
--- /dev/null
+++ b/copy_model.sh
@@ -0,0 +1,133 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# copy_model.sh: clone ~/.exo/models from SOURCE to one or more TARGETS using scp -3.
+# Username defaults:
+#   - If host is "aN" and no user given, username defaults to "aN".
+#   - Otherwise defaults to $(whoami), unless you pass user@host.
+#
+# Examples:
+#   ./copy_model.sh a1 a2 a3
+#   ./copy_model.sh a1 frank@a2 192.168.1.3
+
+if [ $# -lt 2 ]; then
+  echo "Usage: $0 SOURCE TARGET [TARGET...]" >&2
+  exit 2
+fi
+
+SOURCE="$1"
+shift
+TARGETS=("$@")
+
+DEFAULT_USER="$(whoami)"
+MODELS_REL=".exo/models" # relative under $HOME
+
+timestamp() { date "+%Y-%m-%d %H:%M:%S"; }
+
+split_user_host() {
+  local in="$1"
+  if [[ "$in" == *"@"* ]]; then
+    printf "%s|%s" "${in%%@*}" "${in#*@}"
+  else
+    printf "|%s" "$in"
+  fi
+}
+
+resolve_ip() {
+  local hostish="$1"
+  if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
+    echo "192.168.1.${BASH_REMATCH[1]}"
+  else
+    echo "$hostish"
+  fi
+}
+
+default_user_for() {
+  local hostish="$1"
+  if [[ "$hostish" =~ ^a([0-9]+)$ ]]; then
+    echo "$hostish"
+  else
+    echo "$DEFAULT_USER"
+  fi
+}
+
+SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR -o ConnectTimeout=10)
+SSHPASS_BIN="$(command -v sshpass || true)"
+SCP_BIN="${SCP_BIN:-scp}"
+
+read -s -p "Password for all hosts: " PASS
+echo
+if [ -n "$SSHPASS_BIN" ]; then
+  echo "$(timestamp) sshpass found: will provide the password non-interactively."
+else
+  echo "$(timestamp) WARNING: sshpass not found — you’ll be prompted by scp/ssh per hop unless keys are set up."
+fi
+
+# Build source endpoint (default username logic)
+IFS='|' read -r SRC_USER_RAW SRC_HOSTISH <<<"$(split_user_host "$SOURCE")"
+SRC_USER="${SRC_USER_RAW:-$(default_user_for "$SRC_HOSTISH")}"
+SRC_IP="$(resolve_ip "$SRC_HOSTISH")"
+SRC_HOST="${SRC_USER}@${SRC_IP}"
+
+echo "$(timestamp) Source: ${SRC_HOST}:~/${MODELS_REL}"
+echo "$(timestamp) Targets: ${#TARGETS[@]}"
+
+# Helper to run a simple remote command via ssh (for mkdir -p checks)
+ssh_run() {
+  local host="$1"
+  shift
+  if [ -n "$SSHPASS_BIN" ]; then
+    sshpass -p "$PASS" ssh "${SSH_OPTS[@]}" "$host" "$@"
+  else
+    ssh "${SSH_OPTS[@]}" "$host" "$@"
+  fi
+}
+
+# Ensure the source directory exists (create it if missing)
+ssh_run "$SRC_HOST" "mkdir -p ~/${MODELS_REL}"
+
+failures=0
+count=0
+for T in "${TARGETS[@]}"; do
+  count=$((count + 1))
+  IFS='|' read -r T_USER_RAW T_HOSTISH <<<"$(split_user_host "$T")"
+  T_USER="${T_USER_RAW:-$(default_user_for "$T_HOSTISH")}"
+  T_IP="$(resolve_ip "$T_HOSTISH")"
+  T_HOST="${T_USER}@${T_IP}"
+
+  echo "============================================================"
+  echo "$(timestamp) [${count}/${#TARGETS[@]}] ${SRC_HOST} ==> ${T_HOST}"
+  echo "$(timestamp) Ensuring destination directory exists…"
+  ssh_run "$T_HOST" "mkdir -p ~/${MODELS_REL%/*}" # ~/.exo
+
+  # Copy the whole "models" directory into ~/.exo on the target.
+  # scp -3 = copy between two remotes via local; -r recursive; -p preserve times/modes
+  if [ -n "$SSHPASS_BIN" ]; then
+    echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
+    if sshpass -p "$PASS" "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
+      "${SRC_HOST}:~/${MODELS_REL}" \
+      "${T_HOST}:~/.exo/"; then
+      echo "$(timestamp) [${count}] Done: ${T_HOST}"
+    else
+      echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
+      failures=$((failures + 1))
+    fi
+  else
+    echo "$(timestamp) Running: scp -3 -rp ${SRC_HOST}:~/${MODELS_REL} ${T_HOST}:~/.exo/"
+    if "$SCP_BIN" "${SSH_OPTS[@]}" -3 -rp \
+      "${SRC_HOST}:~/${MODELS_REL}" \
+      "${T_HOST}:~/.exo/"; then
+      echo "$(timestamp) [${count}] Done: ${T_HOST}"
+    else
+      echo "$(timestamp) [${count}] ERROR during scp to ${T_HOST}" >&2
+      failures=$((failures + 1))
+    fi
+  fi
+done
+
+echo "============================================================"
+if [ "$failures" -eq 0 ]; then
+  echo "$(timestamp) All transfers completed successfully."
+else
+  echo "$(timestamp) Completed with ${failures} failure(s)."
+fi
diff --git a/dashboard/index.html b/dashboard/index.html
index 433746fe..85f94589 100644
--- a/dashboard/index.html
+++ b/dashboard/index.html
@@ -943,7 +943,7 @@
             }
 
             const result = await response.json();
-            showLaunchStatus(`Instance launched successfully: ${result.instance_id}`, 'success');
+            showLaunchStatus('Instance launched successfully');
 
             // Reset form
             modelSelect.value = '';
diff --git a/flake.lock b/flake.lock
index bc30d2b3..35076eff 100644
--- a/flake.lock
+++ b/flake.lock
@@ -1,5 +1,26 @@
 {
   "nodes": {
+    "fenix": {
+      "inputs": {
+        "nixpkgs": [
+          "nixpkgs"
+        ],
+        "rust-analyzer-src": "rust-analyzer-src"
+      },
+      "locked": {
+        "lastModified": 1755585599,
+        "narHash": "sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY=",
+        "owner": "nix-community",
+        "repo": "fenix",
+        "rev": "6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-community",
+        "repo": "fenix",
+        "type": "github"
+      }
+    },
     "flake-compat": {
       "flake": false,
       "locked": {
@@ -102,12 +123,30 @@
     },
     "root": {
       "inputs": {
+        "fenix": "fenix",
         "flake-parts": "flake-parts",
         "flake-root": "flake-root",
         "just-flake": "just-flake",
         "make-shell": "make-shell",
         "nixpkgs": "nixpkgs"
       }
+    },
+    "rust-analyzer-src": {
+      "flake": false,
+      "locked": {
+        "lastModified": 1755504847,
+        "narHash": "sha256-VX0B9hwhJypCGqncVVLC+SmeMVd/GAYbJZ0MiiUn2Pk=",
+        "owner": "rust-lang",
+        "repo": "rust-analyzer",
+        "rev": "a905e3b21b144d77e1b304e49f3264f6f8d4db75",
+        "type": "github"
+      },
+      "original": {
+        "owner": "rust-lang",
+        "ref": "nightly",
+        "repo": "rust-analyzer",
+        "type": "github"
+      }
     }
   },
   "root": "root",
diff --git a/flake.nix b/flake.nix
index 0098a869..b1f69a86 100644
--- a/flake.nix
+++ b/flake.nix
@@ -20,47 +20,39 @@
     # Provides flake integration with [Just](https://just.systems/man/en/)
     just-flake.url = "github:juspay/just-flake";
+
+    # Provides Rust dev-env integration:
+    fenix = {
+      url = "github:nix-community/fenix";
+      inputs.nixpkgs.follows = "nixpkgs";
+    };
   };
 
+  # TODO: figure out caching story
+  # nixConfig = {
+  #   # nix community cachix
+  #   extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
+  #   extra-substituters = "https://nix-community.cachix.org";
+  # };
+
   outputs =
     inputs@{ flake-parts, ... }:
     flake-parts.lib.mkFlake { inherit inputs; } (
-      {
-        flake-parts-lib,
-        self,
-        ...
-      }:
-      let
-        nixpkgs-lib = inputs.nixpkgs.lib;
-
-        # A wraper around importApply that supplies default parameters
-        importApply' =
-          path: extraParams:
-          (flake-parts-lib.importApply path (
-            nixpkgs-lib.recursiveUpdate {
-              localSelf = self;
-              inherit flake-parts-lib;
-              inherit nixpkgs-lib;
-            } extraParams
-          ));
-
-        # instantiate all the flake modules, passing custom arguments to them as needed
-        flakeModules = {
-          flakeRoot = importApply' ./.flake-modules/flake-root.nix { inherit (inputs) flake-root; };
-          justFlake = importApply' ./.flake-modules/just-flake.nix { inherit (inputs) just-flake; };
-          goForwarder = importApply' ./.flake-modules/go-forwarder.nix { };
-        };
-      in
+      { flake-parts-lib, self, ... }:
       {
         imports = [
           inputs.make-shell.flakeModules.default
-          flakeModules.flakeRoot
-          flakeModules.justFlake
-          flakeModules.goForwarder
-          ./.flake-modules/macmon.nix
+
+          ./nix/modules/pkgs-init.nix # nixpkgs overlays manager
+          ./nix/modules/flake-root.nix
+          ./nix/modules/just-flake.nix
+          ./nix/modules/macmon.nix
+          ./nix/modules/python.nix
+          ./nix/modules/rust.nix
+          ./nix/modules/go-forwarder.nix
         ];
         systems = [
           "x86_64-linux"
@@ -75,55 +67,31 @@
             system,
             ...
           }:
-          let
-            buildInputs = with pkgs; [
-            ];
-            nativeBuildInputs = with pkgs; [
-            ];
-          in
           {
             # Per-system attributes can be defined here. The self' and inputs'
             # module parameters provide easy access to attributes of the same
             # system.
 
             # NOTE: pkgs is equivalent to inputs'.nixpkgs.legacyPackages.hello;
-            apps = {
-              python-lsp = {
-                type = "app";
-                program = "${pkgs.basedpyright}/bin/basedpyright-langserver";
-              };
-              default = self'.apps.forwarder;
-            };
+            apps = { };
 
             make-shells.default = {
               packages = [
-                pkgs.python313
-                pkgs.uv
                 pkgs.protobuf
-                pkgs.basedpyright
-                pkgs.ruff
               ];
 
-              nativeBuildInputs =
-                with pkgs;
-                [
-                  nixpkgs-fmt
-                  cmake
-                ]
-                ++ buildInputs
-                ++ nativeBuildInputs;
-
-              # Arguments which are intended to be environment variables in the shell environment
-              # should be changed to attributes of the `env` option
-              env = {
-                # fixes libstdc++.so issues and libgl.so issues
-                LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib";
-              };
+              nativeBuildInputs = with pkgs; [
+                nixpkgs-fmt
+              ];
 
               shellHook = ''
                 export GO_BUILD_DIR=$(git rev-parse --show-toplevel)/build;
                 export DASHBOARD_DIR=$(git rev-parse --show-toplevel)/dashboard;
               '';
 
+              # Arguments which are intended to be environment variables in the shell environment
+              # should be changed to attributes of the `env` option
+              env = { };
+
               # Arbitrary mkDerivation arguments should be changed to be attributes of the `additionalArguments` option
               additionalArguments = { };
             };
diff --git a/kill_remote.sh b/kill_remote.sh
new file mode 100755
index 00000000..727b3261
--- /dev/null
+++ b/kill_remote.sh
@@ -0,0 +1,69 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+###############################################################################
+# Args & prerequisites
+###############################################################################
+if [[ $# -gt 1 ]]; then
+  echo "Usage: $0 [hosts_file]" >&2
+  exit 1
+fi
+HOSTS_FILE=${1:-hosts.txt}
+
+###############################################################################
+# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
+###############################################################################
+if [[ ! -f "$HOSTS_FILE" ]]; then
+  echo "Error: $HOSTS_FILE not found"
+  exit 1
+fi
+
+if builtin command -v mapfile >/dev/null 2>&1; then
+  mapfile -t HOSTS <"$HOSTS_FILE"
+else
+  HOSTS=()
+  while IFS= read -r h; do
+    [[ -n "$h" ]] && HOSTS+=("$h")
+  done <"$HOSTS_FILE"
+fi
+[[ ${#HOSTS[@]} -gt 0 ]] || {
+  echo "No hosts found in $HOSTS_FILE"
+  exit 1
+}
+
+###############################################################################
+# Helper – run a remote command, preserving its exit status
+###############################################################################
+ssh_opts=(-o StrictHostKeyChecking=no
+  -o LogLevel=ERROR)
+
+run_remote() { # $1 host  $2 command
+  local host=$1 cmd=$2 rc
+  if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
+    rc=0
+  else
+    rc=$?
+  fi
+  return $rc
+}
+
+###############################################################################
+# Kill exo everywhere (parallel)
+###############################################################################
+echo "=== Killing exo on ${#HOSTS[@]} host(s) ==="
+# Collect each background job's PID and `wait` on it individually: assigning a
+# flag inside a backgrounded subshell would never propagate back to this shell.
+pids=()
+for h in "${HOSTS[@]}"; do
+  run_remote "$h" 'pkill -f exo || true' &
+  pids+=("$!")
+done
+fail=0
+for pid in "${pids[@]}"; do
+  wait "$pid" || fail=1
+done
+((fail == 0)) || {
+  echo "❌ Some hosts could not be reached—check SSH access."
+  exit 1
+}
+echo "✓ exo processes killed on all reachable hosts."
\ No newline at end of file
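All three remote-control scripts (kill_remote.sh, remote_git.sh, run_remote.sh) share the same fan-out shape: read hosts.txt, run one ssh per host in parallel, and aggregate per-host failures. For comparison, a rough Python sketch of the same pattern (illustrative only; this helper is not part of the diff, and the hosts file and ssh flags simply mirror the scripts):

    import subprocess
    import sys
    from concurrent.futures import ThreadPoolExecutor

    SSH_OPTS = ["-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR"]

    def run_remote(host: str, cmd: str) -> bool:
        # one ssh invocation per host; True iff the remote command succeeded
        return subprocess.run(["ssh", *SSH_OPTS, host, cmd]).returncode == 0

    def main() -> int:
        with open("hosts.txt") as f:
            hosts = [line.strip() for line in f if line.strip()]
        with ThreadPoolExecutor(max_workers=len(hosts) or 1) as pool:
            ok = list(pool.map(lambda h: run_remote(h, "pkill -f exo || true"), hosts))
        if all(ok):
            print("All hosts processed successfully.")
            return 0
        print("Some hosts could not be reached; check SSH access.")
        return 1

    if __name__ == "__main__":
        sys.exit(main())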
diff --git a/.flake-modules/flake-root.nix b/nix/modules/flake-root.nix
similarity index 55%
rename from .flake-modules/flake-root.nix
rename to nix/modules/flake-root.nix
index 02ca1735..6b000405 100644
--- a/.flake-modules/flake-root.nix
+++ b/nix/modules/flake-root.nix
@@ -2,39 +2,14 @@
 # 1. ${lib.getExe config.flake-root.package}
 # 2. $FLAKE_ROOT environment-varible
 
-# Top-level parameters that are bound to the provider flake
-# These are passed from `flake.nix` using importApply
-{
-  localSelf,
-  flake-parts-lib,
-  nixpkgs-lib,
-  flake-root,
-  ...
-}:
-
 # These values would bind to the consumer flake when this flake module is imported:
-{
-  config,
-  self,
-  inputs,
-  getSystem,
-  moduleWithSystem,
-  withSystem,
-  ...
-}:
+{ inputs, ... }:
 
 # The actual flake-parts module configuration
 {
-  imports = [ flake-root.flakeModule ];
+  imports = [ inputs.flake-root.flakeModule ];
   perSystem =
-    {
-      config,
-      self',
-      inputs',
-      pkgs,
-      system,
-      ...
-    }:
+    { config, ... }:
     {
       flake-root.projectRootFile = "flake.nix"; # Not necessary, as flake.nix is the default
diff --git a/nix/modules/go-forwarder.nix b/nix/modules/go-forwarder.nix
new file mode 100644
index 00000000..1ef6857c
--- /dev/null
+++ b/nix/modules/go-forwarder.nix
@@ -0,0 +1,19 @@
+{
+  perSystem =
+    {
+      config,
+      pkgs,
+      lib,
+      ...
+    }:
+    {
+      make-shells.default = {
+        # Go 1.24 compiler – align with go.mod
+        packages = [ pkgs.go_1_24 ];
+        shellHook = ''
+          GOPATH="''$(${lib.getExe config.flake-root.package})"/.go_cache
+          export GOPATH
+        '';
+      };
+    };
+}
diff --git a/nix/modules/just-flake.nix b/nix/modules/just-flake.nix
new file mode 100644
index 00000000..e7a0d2db
--- /dev/null
+++ b/nix/modules/just-flake.nix
@@ -0,0 +1,26 @@
+# Provides pretty banner & command index for this flake
+
+{ inputs, ... }:
+{
+  imports = [ inputs.just-flake.flakeModule ];
+  perSystem =
+    { config, ... }:
+    {
+      just-flake.features = {
+        # treefmt.enable = true;
+        # rust.enable = true;
+        # convco.enable = true;
+        # hello = {
+        #   enable = true;
+        #   justfile = ''
+        #     hello:
+        #       echo Hello World
+        #   '';
+        # };
+      };
+
+      make-shells.default = {
+        inputsFrom = [ config.just-flake.outputs.devShell ];
+      };
+    };
+}
diff --git a/nix/modules/macmon.nix b/nix/modules/macmon.nix
new file mode 100644
index 00000000..23fa9457
--- /dev/null
+++ b/nix/modules/macmon.nix
@@ -0,0 +1,12 @@
+{
+  perSystem =
+    { lib, pkgs, ... }:
+    lib.mkMerge [
+      (lib.mkIf pkgs.stdenv.isDarwin {
+        make-shells.default = {
+          packages = [ pkgs.macmon ];
+        };
+      })
+    ];
+
+}
diff --git a/nix/modules/pkgs-init.nix b/nix/modules/pkgs-init.nix
new file mode 100644
index 00000000..f75c5944
--- /dev/null
+++ b/nix/modules/pkgs-init.nix
@@ -0,0 +1,62 @@
+# Single module responsible for collecting all overlays and instantiating in one go
+
+{
+  flake-parts-lib,
+  inputs,
+  self,
+  specialArgs,
+  ...
+}:
+let
+  inherit (flake-parts-lib) mkPerSystemOption;
+in
+{
+  options.perSystem = mkPerSystemOption (
+    {
+      system,
+      config,
+      lib,
+      options,
+      pkgs,
+      self',
+      ...
+    }@args:
+    let
+      inherit (lib.types)
+        attrsOf
+        listOf
+        submoduleWith
+        raw
+        ;
+    in
+    {
+      options.pkgs-init.overlays = lib.mkOption {
+        description = ''
+          List of nixpkgs overlays (functions of the form: final: prev: { ... }).
+          Any module can append. Order matters.
+        '';
+        default = [ ];
+        example = [
+          (final: prev: {
+            my-hello = prev.hello;
+          })
+        ];
+        type = lib.types.listOf lib.types.unspecified;
+      };
+      options.pkgs-init.importArgs = lib.mkOption {
+        description = "Extra arguments merged into the nixpkgs import call.";
+        default = { };
+        type = lib.types.attrs;
+      };
+      config = {
+        _module.args.pkgs = import inputs.nixpkgs (
+          {
+            inherit system;
+            overlays = config.pkgs-init.overlays;
+          }
+          // config.pkgs-init.importArgs
+        );
+      };
+    }
+  );
+}
diff --git a/nix/modules/python.nix b/nix/modules/python.nix
new file mode 100644
index 00000000..ccda8358
--- /dev/null
+++ b/nix/modules/python.nix
@@ -0,0 +1,20 @@
+# Configures Python shell
+
+{
+  perSystem =
+    { pkgs, ... }:
+    {
+      make-shells.default = {
+        packages = [
+          pkgs.python313
+          pkgs.uv
+          pkgs.ruff
+          pkgs.basedpyright
+        ];
+
+        shellHook = ''
+          export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${pkgs.python313}/lib
+        '';
+      };
+    };
+}
diff --git a/nix/modules/rust.nix b/nix/modules/rust.nix
new file mode 100644
index 00000000..1eb4865d
--- /dev/null
+++ b/nix/modules/rust.nix
@@ -0,0 +1,25 @@
+# Configures Rust shell
+
+{ inputs, ... }:
+{
+  perSystem =
+    { pkgs, ... }:
+    {
+      pkgs-init.overlays = [
+        inputs.fenix.overlays.default
+      ];
+
+      make-shells.default = {
+        packages = [
+          (pkgs.fenix.complete.withComponents [
+            "cargo"
+            "rustc"
+            "clippy"
+            "rustfmt"
+            "rust-src"
+          ])
+          pkgs.rustup # literally only added to make RustRover happy (otherwise useless)
+        ];
+      };
+    };
+}
diff --git a/pyproject.toml b/pyproject.toml
index ba64ebba..8759a9d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,9 @@ dependencies = [
     "cobs>=1.2.2",
     "loguru>=0.7.3",
     "textual>=5.3.0",
+    "exo_pyo3_bindings", # rust bindings
+    "anyio>=4.10.0",
+    "bidict>=0.23.1",
 ]
 
 [project.scripts]
@@ -61,8 +64,12 @@ darwin = [
 [tool.uv.workspace]
 members = [
     "scripts",
+    "rust/exo_pyo3_bindings",
 ]
 
+[tool.uv.sources]
+exo_pyo3_bindings = { workspace = true }
+
 [build-system]
 requires = ["uv_build>=0.8.9,<0.9.0"]
 build-backend = "uv_build"
@@ -87,7 +94,7 @@ reportUnnecessaryTypeIgnoreComment = "error"
 pythonVersion = "3.13"
 pythonPlatform = "Darwin"
 
-exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts"]
+exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
 
 stubPath = "typings"
 
 [[tool.basedpyright.executionEnvironments]]
diff --git a/remote_git.sh b/remote_git.sh
index 5c9c003d..73ce84bd 100755
--- a/remote_git.sh
+++ b/remote_git.sh
@@ -4,47 +4,49 @@ set -euo pipefail
 ###############################################################################
 # Args & prerequisites
 ###############################################################################
-if [[ $# -lt 2 ]]; then
-  echo "Usage: $0 <password> <git_command> [git_args...]" >&2
+if [[ $# -lt 1 ]]; then
+  echo "Usage: $0 <git_command> [git_args...]" >&2
   echo "Examples:" >&2
-  echo "  $0 mypassword pull" >&2
-  echo "  $0 mypassword checkout main" >&2
-  echo "  $0 mypassword status" >&2
-  echo "  $0 mypassword fetch --all" >&2
+  echo "  $0 pull" >&2
+  echo "  $0 checkout main" >&2
+  echo "  $0 status" >&2
+  echo "  $0 fetch --all" >&2
   exit 1
 fi
 
-PASSWORD=$1
-shift # Remove password from args
-GIT_CMD="$*" # Remaining args form the git command
-HOSTS_FILE=${HOSTS_FILE:-hosts.json}
-
-for prog in jq sshpass; do
-  command -v "$prog" >/dev/null ||
-    { echo "Error: $prog not installed."; exit 1; }
-done
+GIT_CMD="$*" # All args form the git command
+HOSTS_FILE=${HOSTS_FILE:-hosts.txt}
 
 ###############################################################################
-# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
+# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
 ###############################################################################
+if [[ ! -f "$HOSTS_FILE" ]]; then
+  echo "Error: $HOSTS_FILE not found"
+  exit 1
+fi
+
 if builtin command -v mapfile >/dev/null 2>&1; then
-  mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
+  mapfile -t HOSTS <"$HOSTS_FILE"
 else
   HOSTS=()
-  while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
+  while IFS= read -r h; do
+    [[ -n "$h" ]] && HOSTS+=("$h")
+  done <"$HOSTS_FILE"
 fi
-[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
+[[ ${#HOSTS[@]} -gt 0 ]] || {
+  echo "No hosts found in $HOSTS_FILE"
+  exit 1
+}
 
 ###############################################################################
 # Helper – run a remote command and capture rc/stderr/stdout
 ###############################################################################
 ssh_opts=(-o StrictHostKeyChecking=no
-  -o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
-  -o LogLevel=ERROR)
+  -o LogLevel=ERROR)
 
-run_remote () { # $1 host  $2 command
+run_remote() { # $1 host  $2 command
   local host=$1 cmd=$2 rc
-  if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
+  if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
     rc=0
   else
     rc=$?
@@ -72,9 +74,9 @@ done
 wait
 
 echo ""
-if (( fail == 0 )); then
+if ((fail == 0)); then
   echo "🎉 Git command executed successfully on all hosts!"
 else
   echo "⚠️  Some hosts failed—see above."
   exit 1
-fi
\ No newline at end of file
+fi
diff --git a/run_remote.sh b/run_remote.sh
index 87ee2638..2b654e10 100755
--- a/run_remote.sh
+++ b/run_remote.sh
@@ -4,38 +4,42 @@ set -euo pipefail
 ###############################################################################
 # Args & prerequisites
 ###############################################################################
-if [[ $# -lt 1 || $# -gt 2 ]]; then
-  echo "Usage: $0 <password> [hosts_file]" >&2 ; exit 1
+if [[ $# -gt 1 ]]; then
+  echo "Usage: $0 [hosts_file]" >&2
+  exit 1
 fi
-PASSWORD=$1
-HOSTS_FILE=${2:-hosts.json}
-
-for prog in jq sshpass; do
-  command -v "$prog" >/dev/null ||
-    { echo "Error: $prog not installed."; exit 1; }
-done
+HOSTS_FILE=${1:-hosts.txt}
 
 ###############################################################################
-# Load hosts.json (works on macOS Bash 3.2 and Bash 4+)
+# Load hosts.txt (works on macOS Bash 3.2 and Bash 4+)
 ###############################################################################
+if [[ ! -f "$HOSTS_FILE" ]]; then
+  echo "Error: $HOSTS_FILE not found"
+  exit 1
+fi
+
 if builtin command -v mapfile >/dev/null 2>&1; then
-  mapfile -t HOSTS < <(jq -r '.[]' "$HOSTS_FILE")
+  mapfile -t HOSTS <"$HOSTS_FILE"
 else
   HOSTS=()
-  while IFS= read -r h; do HOSTS+=("$h"); done < <(jq -r '.[]' "$HOSTS_FILE")
+  while IFS= read -r h; do
+    [[ -n "$h" ]] && HOSTS+=("$h")
+  done <"$HOSTS_FILE"
 fi
-[[ ${#HOSTS[@]} -gt 0 ]] || { echo "No hosts found in $HOSTS_FILE"; exit 1; }
+[[ ${#HOSTS[@]} -gt 0 ]] || {
+  echo "No hosts found in $HOSTS_FILE"
+  exit 1
+}
 
 ###############################################################################
 # Helper – run a remote command and capture rc/stderr/stdout
 ###############################################################################
 ssh_opts=(-o StrictHostKeyChecking=no
-  -o NumberOfPasswordPrompts=1 # allow sshpass to answer exactly once
-  -o LogLevel=ERROR)
+  -o LogLevel=ERROR)
 
-run_remote () { # $1 host  $2 command
+run_remote() { # $1 host  $2 command
   local host=$1 cmd=$2 rc
-  if sshpass -p "$PASSWORD" ssh "${ssh_opts[@]}" "$host" "$cmd"; then
+  if ssh "${ssh_opts[@]}" "$host" "$cmd"; then
     rc=0
   else
     rc=$?
@@ -54,26 +58,42 @@ for h in "${HOSTS[@]}"; do
   ) || fail=1 &
 done
 wait
-(( fail == 0 )) || { echo "❌ Some hosts could not be reached—check password or SSH access."; exit 1; }
+((fail == 0)) || {
+  echo "❌ Some hosts could not be reached—check SSH access."
+  exit 1
+}
 echo "✓ exo processes killed on all reachable hosts."
-
+#
 ###############################################################################
-# Phase 2 – start new exo processes (parallel, with sudo -S)
+# Phase 2 – cleanup database files (parallel)
 ###############################################################################
-echo "=== Stage 2: starting new exo processes ==="
+echo "=== Stage 2: cleaning up database files ==="
 fail=0
-for i in "${!HOSTS[@]}"; do
-  h=${HOSTS[$i]}
-
-  # one liner that pre-caches sudo and then runs the script
-  if [[ $i -eq 0 ]]; then
-    remote_cmd="cd ~/exo && ./run.sh -c"
-  else
-    remote_cmd="cd ~/exo && ./run.sh -rc"
-  fi
-
-  ( run_remote "$h" "$remote_cmd" ) || fail=1 &
+for h in "${HOSTS[@]}"; do
+  (
+    run_remote "$h" 'rm -f ~/.exo/*db* || true'
+  ) || fail=1 &
 done
 wait
-(( fail == 0 )) && echo "🎉 Deployment finished!" || \
-  { echo "⚠️  Some starts failed—see above."; exit 1; }
+((fail == 0)) || {
+  echo "❌ Some hosts failed database cleanup."
+  exit 1
+}
+echo "✓ Database files cleaned on all hosts."
+
+###############################################################################
+# Phase 3 – start new exo processes in Terminal windows (parallel)
+###############################################################################
+echo "=== Stage 3: starting new exo processes ==="
+fail=0
+for h in "${HOSTS[@]}"; do
+  # Use osascript to open Terminal windows on remote Mac
+  remote_cmd="osascript -e \"tell app \\\"Terminal\\\" to do script \\\"cd ~/exo; nix develop --command uv run exo\\\"\""
+
+  (run_remote "$h" "$remote_cmd") || fail=1 &
+done
+wait
+((fail == 0)) && echo "🎉 Deployment finished!" || {
+  echo "⚠️  Some starts failed—see above."
+  exit 1
+}
diff --git a/rust/.gitignore b/rust/.gitignore
new file mode 100644
index 00000000..1256dafb
--- /dev/null
+++ b/rust/.gitignore
@@ -0,0 +1,15 @@
+# Generated by Cargo
+# will have compiled files and executables
+debug
+target
+Cargo.lock
+
+# These are backup files generated by rustfmt
+**/*.rs.bk
+
+# MSVC Windows builds of rustc generate these, which store debugging information
+*.pdb
+
+# Generated by cargo mutants
+# Contains mutation testing data
+**/mutants.out*/
\ No newline at end of file
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
new file mode 100644
index 00000000..f45941f4
--- /dev/null
+++ b/rust/Cargo.toml
@@ -0,0 +1,165 @@
+[workspace]
+resolver = "3"
+members = [
+    "networking",
+    "exo_pyo3_bindings",
+    "system_custodian",
+    "util",
+]
+
+[workspace.package]
+version = "0.0.1"
+edition = "2024"
+
+[profile.dev]
+opt-level = 1
+debug = true
+
+[profile.release]
+opt-level = 3
+
+# Common shared dependencies configured once at the workspace
+# level, to be re-used more easily across workspace member crates.
+#
+# Common configurations include versions, paths, features, etc.
+[workspace.dependencies]
+## Crate members as common dependencies
+networking = { path = "networking" }
+system_custodian = { path = "system_custodian" }
+util = { path = "util" }
+
+# Proc-macro authoring tools
+syn = "2.0"
+quote = "1.0"
+proc-macro2 = "1.0"
+darling = "0.20"
+
+# Macro dependencies
+extend = "1.2"
+delegate = "0.13"
+impl-trait-for-tuples = "0.2"
+clap = "4.5"
+derive_more = { version = "2.0.1", features = ["display"] }
+pin-project = "1"
+
+# Utility dependencies
+itertools = "0.14"
+thiserror = "2"
+internment = "0.8"
+recursion = "0.5"
+regex = "1.11"
+once_cell = "1.21"
+thread_local = "1.1"
+bon = "3.4"
+generativity = "1.1"
+anyhow = "1.0"
+keccak-const = "0.2"
+
+# Functional generics/lenses frameworks
+frunk_core = "0.4"
+frunk = "0.4"
+frunk_utils = "0.2"
+frunk-enum-core = "0.3"
+
+# Async dependencies
+tokio = "1.46"
+futures = "0.3"
+futures-util = "0.3"
+futures-timer = "3.0"
+
+# Data structures
+either = "1.15"
+ordered-float = "5.0"
+ahash = "0.8"
+
+# Tracing/logging
+log = "0.4"
+
+# networking
+libp2p = "0.56"
+libp2p-tcp = "0.44"
+
+[workspace.lints.rust]
+static_mut_refs = "warn" # Or use "deny" instead of warn
+incomplete_features = "allow"
+
+# Clippy's lint category level configurations;
+# every member crate needs to inherit these by adding
+#
+# ```toml
+# [lints]
+# workspace = true
+# ```
+#
+# to their `Cargo.toml` files
+[workspace.lints.clippy]
+# Clippy lint categories meant to be enabled all at once
+correctness = { level = "deny", priority = -1 }
+suspicious = { level = "warn", priority = -1 }
+style = { level = "warn", priority = -1 }
+complexity = { level = "warn", priority = -1 }
+perf = { level = "warn", priority = -1 }
+pedantic = { level = "warn", priority = -1 }
+nursery = { level = "warn", priority = -1 }
+cargo = { level = "warn", priority = -1 }
+
+# Individual Clippy lints from the `restriction` category
+arithmetic_side_effects = "warn"
+as_conversions = "warn"
+assertions_on_result_states = "warn"
+clone_on_ref_ptr = "warn"
+decimal_literal_representation = "warn"
+default_union_representation = "warn"
+deref_by_slicing = "warn"
+disallowed_script_idents = "deny"
+else_if_without_else = "warn"
+empty_enum_variants_with_brackets = "warn"
+empty_structs_with_brackets = "warn"
+error_impl_error = "warn"
+exit = "deny"
+expect_used = "warn"
+float_cmp_const = "warn"
+get_unwrap = "warn"
+if_then_some_else_none = "warn"
+impl_trait_in_params = "warn"
+indexing_slicing = "warn"
+infinite_loop = "warn"
+let_underscore_must_use = "warn"
+let_underscore_untyped = "warn"
+lossy_float_literal = "warn"
+mem_forget = "warn"
+missing_inline_in_public_items = "warn"
+multiple_inherent_impl = "warn"
+multiple_unsafe_ops_per_block = "warn"
+mutex_atomic = "warn"
+non_zero_suggestions = "warn"
+panic = "warn"
+partial_pub_fields = "warn"
+pattern_type_mismatch = "warn"
+pub_without_shorthand = "warn"
+rc_buffer = "warn"
+rc_mutex = "warn"
+redundant_type_annotations = "warn"
+renamed_function_params = "warn"
+rest_pat_in_fully_bound_structs = "warn"
+same_name_method = "warn"
+self_named_module_files = "deny"
+semicolon_inside_block = "warn"
+shadow_same = "warn"
+shadow_unrelated = "warn"
+str_to_string = "warn"
+string_add = "warn"
+string_lit_chars_any = "warn"
+string_to_string = "warn"
+tests_outside_test_module = "warn"
+todo = "warn"
+try_err = "warn"
+undocumented_unsafe_blocks = "warn"
+unnecessary_safety_comment = "warn"
+unnecessary_safety_doc = "warn"
+unneeded_field_pattern = "warn"
+unseparated_literal_suffix = "warn"
+unused_result_ok = "warn"
+unused_trait_names = "warn"
+unwrap_used = "warn"
+verbose_file_reads = "warn"
\ No newline at end of file
diff --git a/rust/clippy.toml b/rust/clippy.toml
new file mode 100644
index 00000000..6d5a6187
--- /dev/null
+++ b/rust/clippy.toml
@@ -0,0 +1,2 @@
+# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
+#allowed-duplicate-crates = ["hashbrown"]
\ No newline at end of file
diff --git a/rust/exo_pyo3_bindings/Cargo.toml b/rust/exo_pyo3_bindings/Cargo.toml
new file mode 100644
index 00000000..4895ecf4
--- /dev/null
+++ b/rust/exo_pyo3_bindings/Cargo.toml
@@ -0,0 +1,77 @@
+[package]
+name = "exo_pyo3_bindings"
+version = { workspace = true }
+edition = { workspace = true }
+publish = false
+
+[lib]
+doctest = false
+path = "src/lib.rs"
+name = "exo_pyo3_bindings"
+
+# "cdylib" needed to produce shared library for Python to import
+# "rlib" needed for stub-gen to run
+crate-type = ["cdylib", "rlib"]
+
+[[bin]]
+path = "src/bin/stub_gen.rs"
+name = "stub_gen"
+doc = false
+
+[lints]
+workspace = true
+
+[dependencies]
+networking = { workspace = true }
+
+# interop
+pyo3 = { version = "0.25.1", features = [ # TODO: migrate to v0.26 soon!!
+    # "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
+    "nightly", # enables better-supported GIL integration
+    "experimental-async", # async support in #[pyfunction] & #[pymethods]
+    #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
+    #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
+    "multiple-pymethods", # allows multiple #[pymethods] sections per class
+
+    # integrations with other libraries
+    "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
+    "ordered-float", "rust_decimal", "smallvec",
+    # "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
+] }
+pyo3-stub-gen = { version = "0.13.1" }
+pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] }
+
+# macro dependencies
+extend = { workspace = true }
+delegate = { workspace = true }
+impl-trait-for-tuples = { workspace = true }
+derive_more = { workspace = true }
+pin-project = { workspace = true }
+
+# async runtime
+tokio = { workspace = true, features = ["full", "tracing"] }
+futures = { workspace = true }
+
+# utility dependencies
+once_cell = "1.21.3"
+thread_local = "1.1.9"
+util = { workspace = true }
+thiserror = { workspace = true }
+#internment = { workspace = true }
+#recursion = { workspace = true }
+#generativity = { workspace = true }
+#itertools = { workspace = true }
+
+
+# Tracing
+#tracing = "0.1"
+#tracing-subscriber = "0.3"
+#console-subscriber = "0.1.5"
+#tracing-log = "0.2.0"
+log = { workspace = true }
+env_logger = "0.11"
+pyo3-log = "0.12"
+
+
+# Networking
+libp2p = { workspace = true, features = ["full"] }
diff --git a/rust/exo_pyo3_bindings/README.md b/rust/exo_pyo3_bindings/README.md
new file mode 100644
index 00000000..e739dd89
--- /dev/null
+++ b/rust/exo_pyo3_bindings/README.md
@@ -0,0 +1 @@
+TODO: do something here....
diff --git a/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi b/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi
new file mode 100644
index 00000000..cf2214cd
--- /dev/null
+++ b/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi
@@ -0,0 +1,207 @@
+# This file is automatically generated by pyo3_stub_gen
+# ruff: noqa: E501, F401
+
+import builtins
+from enum import Enum
+
+class ConnectionUpdate:
+    @property
+    def update_type(self) -> ConnectionUpdateType:
+        r"""
+        Whether this is a connection or disconnection event
+        """
+    @property
+    def peer_id(self) -> PeerId:
+        r"""
+        Identity of the peer that we have connected to or disconnected from.
+        """
+    @property
+    def remote_ipv4(self) -> builtins.str:
+        r"""
+        Remote connection's IPv4 address.
+        """
+    @property
+    def remote_tcp_port(self) -> builtins.int:
+        r"""
+        Remote connection's TCP port.
+        """
+
+class Keypair:
+    r"""
+    Identity keypair of a node.
+    """
+    @staticmethod
+    def generate_ed25519() -> Keypair:
+        r"""
+        Generate a new Ed25519 keypair.
+        """
+    @staticmethod
+    def generate_ecdsa() -> Keypair:
+        r"""
+        Generate a new ECDSA keypair.
+        """
+    @staticmethod
+    def generate_secp256k1() -> Keypair:
+        r"""
+        Generate a new Secp256k1 keypair.
+        """
+    @staticmethod
+    def from_protobuf_encoding(bytes:bytes) -> Keypair:
+        r"""
+        Decode a private key from a protobuf structure and parse it as a `Keypair`.
+        """
+    @staticmethod
+    def rsa_from_pkcs8(bytes:bytes) -> Keypair:
+        r"""
+        Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
+        format (i.e. unencrypted) as defined in [RFC5208].
+
+        [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
+        """
+    @staticmethod
+    def secp256k1_from_der(bytes:bytes) -> Keypair:
+        r"""
+        Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
+        structure as defined in [RFC5915].
+
+        [RFC5915]: https://tools.ietf.org/html/rfc5915
+        """
+    @staticmethod
+    def ed25519_from_bytes(bytes:bytes) -> Keypair: ...
+    def to_protobuf_encoding(self) -> bytes:
+        r"""
+        Encode a private key as protobuf structure.
+        """
+    def to_peer_id(self) -> PeerId:
+        r"""
+        Convert the `Keypair` into the corresponding `PeerId`.
+        """
+
+class Multiaddr:
+    r"""
+    Representation of a Multiaddr.
+    """
+    @staticmethod
+    def empty() -> Multiaddr:
+        r"""
+        Create a new, empty multiaddress.
+        """
+    @staticmethod
+    def with_capacity(n:builtins.int) -> Multiaddr:
+        r"""
+        Create a new, empty multiaddress with the given capacity.
+        """
+    @staticmethod
+    def from_bytes(bytes:bytes) -> Multiaddr:
+        r"""
+        Parse a `Multiaddr` value from its byte slice representation.
+        """
+    @staticmethod
+    def from_string(string:builtins.str) -> Multiaddr:
+        r"""
+        Parse a `Multiaddr` value from its string representation.
+        """
+    def len(self) -> builtins.int:
+        r"""
+        Return the length in bytes of this multiaddress.
+        """
+    def is_empty(self) -> builtins.bool:
+        r"""
+        Returns true if the length of this multiaddress is 0.
+        """
+    def to_bytes(self) -> bytes:
+        r"""
+        Return a copy of this [`Multiaddr`]'s byte representation.
+        """
+    def to_string(self) -> builtins.str:
+        r"""
+        Convert a Multiaddr to a string.
+        """
+
+class NetworkingHandle:
+    def __new__(cls, identity:Keypair) -> NetworkingHandle: ...
+    async def connection_update_recv(self) -> ConnectionUpdate:
+        r"""
+        Receives the next `ConnectionUpdate` from networking.
+        """
+    async def connection_update_recv_many(self, limit:builtins.int) -> builtins.list[ConnectionUpdate]:
+        r"""
+        Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
+
+        For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
+        For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
+        will sleep until a `ConnectionUpdate` is sent.
+        """
+    async def gossipsub_subscribe(self, topic:builtins.str) -> builtins.bool:
+        r"""
+        Subscribe to a `GossipSub` topic.
+
+        Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
+        """
+    async def gossipsub_unsubscribe(self, topic:builtins.str) -> builtins.bool:
+        r"""
+        Unsubscribes from a `GossipSub` topic.
+
+        Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
+        """
+    async def gossipsub_publish(self, topic:builtins.str, data:bytes) -> None:
+        r"""
+        Publishes a message with multiple topics to the `GossipSub` network.
+
+        If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
+        """
+    async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
+        r"""
+        Receives the next message from the `GossipSub` network.
+        """
+    async def gossipsub_recv_many(self, limit:builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
+        r"""
+        Receives at most `limit` messages from the `GossipSub` network and returns them.
+
+        For `limit = 0`, an empty collection of messages will be returned immediately.
+        For `limit > 0`, if there are no messages in the channel's queue this method
+        will sleep until a message is sent.
+        """
+
+class NoPeersSubscribedToTopicError(builtins.Exception):
+    def __new__(cls, *args) -> NoPeersSubscribedToTopicError: ...
+    def __repr__(self) -> builtins.str: ...
+    def __str__(self) -> builtins.str: ...
+
+class PeerId:
+    r"""
+    Identifier of a peer of the network.
+
+    The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
+    as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
+    """
+    @staticmethod
+    def random() -> PeerId:
+        r"""
+        Generates a random peer ID from a cryptographically secure PRNG.
+
+        This is useful for randomly walking on a DHT, or for testing purposes.
+        """
+    @staticmethod
+    def from_bytes(bytes:bytes) -> PeerId:
+        r"""
+        Parses a `PeerId` from bytes.
+        """
+    def to_bytes(self) -> bytes:
+        r"""
+        Returns a raw bytes representation of this `PeerId`.
+        """
+    def to_base58(self) -> builtins.str:
+        r"""
+        Returns a base-58 encoded string of this `PeerId`.
+        """
+    def __repr__(self) -> builtins.str: ...
+    def __str__(self) -> builtins.str: ...
+
+class ConnectionUpdateType(Enum):
+    r"""
+    Connection or disconnection event discriminant type.
+    """
+    Connected = ...
+    Disconnected = ...
+
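Taken together, the stubs above double as usage documentation for the bindings. A minimal sketch of driving the handle from Python, assuming the flat module layout registered in lib.rs later in this diff (topic name and payload are made up):

    import asyncio
    from exo_pyo3_bindings import (
        Keypair,
        NetworkingHandle,
        NoPeersSubscribedToTopicError,
    )

    async def main() -> None:
        # identity keypair -> networking handle
        handle = NetworkingHandle(Keypair.generate_ed25519())

        # subscribe, then publish; publishing raises if nobody is subscribed yet
        await handle.gossipsub_subscribe("exo-demo")
        try:
            await handle.gossipsub_publish("exo-demo", b"hello")
        except NoPeersSubscribedToTopicError:
            pass  # no remote peers yet; retry once discovery has connected some

        # block until the next message arrives
        topic, data = await handle.gossipsub_recv()
        print(topic, data)

    asyncio.run(main())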
diff --git a/rust/exo_pyo3_bindings/pyproject.toml b/rust/exo_pyo3_bindings/pyproject.toml
new file mode 100644
index 00000000..f1d24cf9
--- /dev/null
+++ b/rust/exo_pyo3_bindings/pyproject.toml
@@ -0,0 +1,32 @@
+[build-system]
+requires = ["maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[project]
+name = "exo_pyo3_bindings"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+authors = [
+    { name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
+]
+requires-python = ">=3.13"
+dependencies = []
+
+[dependency-groups]
+dev = [
+    "exo_pyo3_bindings",
+    "pytest>=8.4.0",
+    "pytest-asyncio>=1.0.0",
+]
+
+[tool.maturin]
+#purelib = true
+#python-source = "python"
+module-name = "exo_pyo3_bindings"
+features = ["pyo3/extension-module", "pyo3/experimental-async"]
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
+asyncio_mode = "auto"
\ No newline at end of file
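Because the pyproject above sets asyncio_mode = "auto" and pulls pytest-asyncio into the dev group, async bindings can be exercised as plain async test functions. A hypothetical sketch (this test file is not part of the diff, and a real test would need peers wired up):

    # tests/test_gossipsub.py (hypothetical)
    from exo_pyo3_bindings import Keypair, NetworkingHandle

    async def test_subscribe_is_idempotent() -> None:
        handle = NetworkingHandle(Keypair.generate_ed25519())
        assert await handle.gossipsub_subscribe("t") is True
        # a second subscription to the same topic reports False per the stubs
        assert await handle.gossipsub_subscribe("t") is False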
diff --git a/rust/exo_pyo3_bindings/src/allow_threading.rs b/rust/exo_pyo3_bindings/src/allow_threading.rs
new file mode 100644
index 00000000..3106e535
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/allow_threading.rs
@@ -0,0 +1,40 @@
+//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
+//!
+
+use pin_project::pin_project;
+use pyo3::marker::Ungil;
+use pyo3::prelude::*;
+use std::{
+    future::Future,
+    pin::{Pin, pin},
+    task::{Context, Poll},
+};
+
+/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
+#[pin_project]
+#[repr(transparent)]
+pub(crate) struct AllowThreads<F>(#[pin] F);
+
+impl<F> AllowThreads<F>
+where
+    Self: Future,
+{
+    pub fn new(f: F) -> Self {
+        Self(f)
+    }
+}
+
+impl<F> Future for AllowThreads<F>
+where
+    F: Future + Ungil,
+    F::Output: Ungil,
+{
+    type Output = F::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let waker = cx.waker();
+        Python::with_gil(|py| {
+            py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
+        })
+    }
+}
diff --git a/rust/exo_pyo3_bindings/src/bin/stub_gen.rs b/rust/exo_pyo3_bindings/src/bin/stub_gen.rs
new file mode 100644
index 00000000..3e30f493
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/bin/stub_gen.rs
@@ -0,0 +1,8 @@
+use pyo3_stub_gen::Result;
+
+fn main() -> Result<()> {
+    env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
+    let stub = exo_pyo3_bindings::stub_info()?;
+    stub.generate()?;
+    Ok(())
+}
diff --git a/rust/exo_pyo3_bindings/src/examples/mod.rs b/rust/exo_pyo3_bindings/src/examples/mod.rs
new file mode 100644
index 00000000..bde14199
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/examples/mod.rs
@@ -0,0 +1,240 @@
+//! This module exists to hold examples of some pyo3 patterns that may be too complex to
+//! re-create from scratch, but too inhomogeneous to create an abstraction/wrapper around.
+//!
+//! Pattern examples include:
+//! - Async task handles: with GC-integrated cleanup
+//! - Sync/async callbacks from python: with proper event-loop handling
+//!
+//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
+//! - Store mutable fields in tokio's `Mutex`
+//! - For async code: take `&self` and `.lock().await`
+//! - For sync code: take `&mut self` and `.get_mut()`
+
+use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
+use futures::FutureExt as _;
+use futures::future::BoxFuture;
+use pyo3::exceptions::PyRuntimeError;
+use pyo3::prelude::{PyModule, PyModuleMethods as _};
+use pyo3::{
+    Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
+};
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::sync::mpsc::error::TryRecvError;
+
+fn needs_tokio_runtime() {
+    tokio::runtime::Handle::current();
+}
+
+type SyncCallback = Box<dyn Fn() + Send + Sync>;
+type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
+
+enum AsyncTaskMessage {
+    SyncCallback(SyncCallback),
+    AsyncCallback(AsyncCallback),
+}
+
+async fn async_task(
+    sender: mpsc::UnboundedSender<()>,
+    mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
+) {
+    log::info!("RUST: async task started");
+
+    // task state
+    let mut interval = tokio::time::interval(Duration::from_secs(1));
+
+    let mut sync_cbs: Vec<SyncCallback> = vec![];
+    let mut async_cbs: Vec<AsyncCallback> = vec![];
+
+    loop {
+        tokio::select! {
+            // handle incoming messages from task-handle
+            message = receiver.recv() => {
+                // handle closed channel by exiting
+                let Some(message) = message else {
+                    log::info!("RUST: channel closed");
+                    break;
+                };
+
+                // dispatch incoming event
+                match message {
+                    AsyncTaskMessage::SyncCallback(cb) => {
+                        sync_cbs.push(cb);
+                    }
+                    AsyncTaskMessage::AsyncCallback(cb) => {
+                        async_cbs.push(cb);
+                    }
+                }
+            }
+
+            // handle all other events
+            _ = interval.tick() => {
+                log::info!("RUST: async task tick");
+
+                // call back all sync callbacks
+                for cb in &sync_cbs {
+                    cb();
+                }
+
+                // call back all async callbacks
+                for cb in &async_cbs {
+                    cb().await;
+                }
+
+                // send event on unbounded channel
+                sender.send(()).expect("handle receiver cannot be closed/dropped");
+            }
+        }
+    }
+
+    log::info!("RUST: async task stopped");
+}
+
+// #[gen_stub_pyclass]
+#[pyclass(name = "AsyncTaskHandle")]
+#[derive(Debug)]
+struct PyAsyncTaskHandle {
+    sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
+    receiver: mpsc::UnboundedReceiver<()>,
+}
+
+#[allow(clippy::expect_used)]
+impl PyAsyncTaskHandle {
+    const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
+        self.sender
+            .as_ref()
+            .expect("The sender should only be None after de-initialization.")
+    }
+
+    const fn sender_mut(&mut self) -> &mut mpsc::UnboundedSender<AsyncTaskMessage> {
+        self.sender
+            .as_mut()
+            .expect("The sender should only be None after de-initialization.")
+    }
+
+    const fn new(
+        sender: mpsc::UnboundedSender<AsyncTaskMessage>,
+        receiver: mpsc::UnboundedReceiver<()>,
+    ) -> Self {
+        Self {
+            sender: Some(sender),
+            receiver,
+        }
+    }
+}
+
+// #[gen_stub_pymethods]
+#[pymethods]
+impl PyAsyncTaskHandle {
+    #[new]
+    fn py_new(py: Python<'_>) -> PyResult<Self> {
+        use pyo3_async_runtimes::tokio::get_runtime;
+
+        // create communication channel TOWARDS our task
+        let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
+
+        // create communication channel FROM our task
+        let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
+
+        // perform necessary setup within tokio context - or it crashes
+        let () = get_runtime().block_on(async { needs_tokio_runtime() });
+
+        // spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
+        _ = get_runtime().spawn_with_scope(py, async move {
+            async_task(t_sender, t_receiver).await;
+        });
+        Ok(Self::new(h_sender, h_receiver))
+    }
+
+    /// NOTE: exceptions in callbacks are silently ignored until end of execution
+    fn add_sync_callback(
+        &self,
+        // #[gen_stub(override_type(
+        //     type_repr="collections.abc.Callable[[], None]",
+        //     imports=("collections.abc")
+        // ))]
+        callback: Py<PyAny>,
+    ) -> PyResult<()> {
+        // blocking call to async method -> can do non-blocking if needed
+        self.sender()
+            .send(AsyncTaskMessage::SyncCallback(Box::new(move || {
+                _ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
+            })))
+            .pyerr()?;
+        Ok(())
+    }
+
+    /// NOTE: exceptions in callbacks are silently ignored until end of execution
+    fn add_async_callback(
+        &self,
+        // #[gen_stub(override_type(
+        //     type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
+        //     imports=("collections.abc")
+        // ))]
+        callback: Py<PyAny>,
+    ) -> PyResult<()> {
+        // blocking call to async method -> can do non-blocking if needed
+        self.sender()
+            .send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
+                let c = Python::with_gil(|py| callback.clone_ref(py));
+                async move {
+                    if let Some(f) = Python::with_gil(|py| {
+                        let coroutine = c.call0(py).write_unraisable_with(py)?;
+                        pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
+                            .write_unraisable_with(py)
+                    }) {
+                        _ = f.await.write_unraisable();
+                    }
+                }
+                .boxed()
+            })))
+            .pyerr()?;
+        Ok(())
+    }
+
+    async fn receive_unit(&mut self) -> PyResult<()> {
+        self.receiver
+            .recv()
+            .await
+            .ok_or(PyErr::new::<PyRuntimeError, _>(
+                "cannot receive unit on closed channel",
+            ))
+    }
+
+    fn drain_units(&mut self) -> PyResult<usize> {
+        let mut cnt = 0;
+        loop {
+            match self.receiver.try_recv() {
+                Err(TryRecvError::Disconnected) => {
+                    return Err(PyErr::new::<PyRuntimeError, _>(
+                        "cannot receive unit on closed channel",
+                    ));
+                }
+                Err(TryRecvError::Empty) => return Ok(cnt),
+                Ok(()) => {
+                    cnt += 1;
+                    continue;
+                }
+            }
+        }
+    }
+
+    // #[gen_stub(skip)]
+    const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
+        Ok(()) // This is needed purely so `__clear__` can work
+    }
+
+    // #[gen_stub(skip)]
+    fn __clear__(&mut self) {
+        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
+        //  to ensure that the networking task is done BEFORE exiting the clear function...
+        //  but this may require GIL?? and it may not be safe to call GIL here??
+        self.sender = None; // Using Option as a trick to force `sender` channel to be dropped
+    }
+}
+
+pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_class::<PyAsyncTaskHandle>()?;
+
+    Ok(())
+}
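Driven from Python, the handle above would look roughly like the sketch below. This is hypothetical usage: main_module in lib.rs (later in this diff) does not actually register examples_submodule, so AsyncTaskHandle is only importable once that wiring is added:

    import asyncio
    from exo_pyo3_bindings import AsyncTaskHandle  # assumes examples_submodule is registered

    async def main() -> None:
        handle = AsyncTaskHandle()
        handle.add_sync_callback(lambda: print("tick"))  # invoked on each task tick
        await handle.receive_unit()                      # wait for one tick event
        print("drained:", handle.drain_units())          # non-blocking drain of queued events

    asyncio.run(main())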
diff --git a/rust/exo_pyo3_bindings/src/lib.rs b/rust/exo_pyo3_bindings/src/lib.rs
new file mode 100644
index 00000000..4f591b8c
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/lib.rs
@@ -0,0 +1,217 @@
+//! TODO: crate documentation
+//!
+//! this is here as a placeholder documentation
+//!
+//!
+
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+#![feature(tuple_trait)]
+#![feature(unboxed_closures)]
+// #![feature(stmt_expr_attributes)]
+// #![feature(assert_matches)]
+// #![feature(async_fn_in_dyn_trait)]
+// #![feature(async_for_loop)]
+// #![feature(auto_traits)]
+// #![feature(negative_impls)]
+
+extern crate core;
+mod allow_threading;
+mod examples;
+pub(crate) mod networking;
+pub(crate) mod pylibp2p;
+
+use crate::networking::networking_submodule;
+use crate::pylibp2p::ident::ident_submodule;
+use crate::pylibp2p::multiaddr::multiaddr_submodule;
+use pyo3::prelude::PyModule;
+use pyo3::prelude::*;
+use pyo3::{Bound, PyResult, pyclass, pymodule};
+use pyo3_stub_gen::define_stub_info_gatherer;
+
+/// Namespace for all the constants used by this crate.
+pub(crate) mod r#const {
+    pub const MPSC_CHANNEL_SIZE: usize = 1024;
+}
+
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {
+    use std::error::Error;
+    use std::marker::Tuple;
+
+    pub trait SendFn<Args: Tuple> = Fn<Args> + Send + 'static;
+
+    pub type AnyError = Box<dyn Error + Send + Sync>;
+    pub type AnyResult<T> = Result<T, AnyError>;
+}
+
+/// Namespace for crate-wide extension traits/methods
+pub(crate) mod ext {
+    use crate::allow_threading::AllowThreads;
+    use extend::ext;
+    use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
+    use pyo3::marker::Ungil;
+    use pyo3::types::PyBytes;
+    use pyo3::{Py, PyErr, PyResult, Python};
+    use tokio::runtime::Runtime;
+    use tokio::sync::mpsc;
+    use tokio::sync::mpsc::error::TryRecvError;
+    use tokio::task::JoinHandle;
+
+    #[ext(pub, name = ByteArrayExt)]
+    impl [u8] {
+        fn pybytes(&self) -> Py<PyBytes> {
+            Python::with_gil(|py| PyBytes::new(py, self).unbind())
+        }
+    }
+
+    #[ext(pub, name = ResultExt)]
+    impl<T, E> Result<T, E>
+    where
+        E: ToString,
+    {
+        fn pyerr(self) -> PyResult<T> {
+            self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
+        }
+    }
+
+    pub trait FutureExt: Future + Sized {
+        /// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
+        fn allow_threads_py(self) -> AllowThreads<Self>
+        where
+            AllowThreads<Self>: Future,
+        {
+            AllowThreads::new(self)
+        }
+    }
+
+    impl<T: Future> FutureExt for T {}
+
+    #[ext(pub, name = PyErrExt)]
+    impl PyErr {
+        fn receiver_channel_closed() -> Self {
+            PyConnectionError::new_err("Receiver channel closed unexpectedly")
+        }
+    }
+
+    #[ext(pub, name = PyResultExt)]
+    impl<T> PyResult<T> {
+        fn write_unraisable(self) -> Option<T> {
+            Python::with_gil(|py| self.write_unraisable_with(py))
+        }
+
+        fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
+            match self {
+                Ok(v) => Some(v),
+                Err(e) => {
+                    // write error back to python
+                    e.write_unraisable(py, None);
+                    None
+                }
+            }
+        }
+    }
+
+    #[ext(pub, name = TokioRuntimeExt)]
+    impl Runtime {
+        fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
+        where
+            F: Future + Send + 'static,
+            F::Output: Send + 'static,
+        {
+            let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
+            Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
+        }
+    }
+
+    #[ext(pub, name = TokioMpscSenderExt)]
+    impl<T> mpsc::Sender<T> {
+        /// Sends a value, waiting until there is capacity.
+        ///
+        /// A successful send occurs when it is determined that the other end of the
+        /// channel has not hung up already. An unsuccessful send would be one where
+        /// the corresponding receiver has already been closed.
+        async fn send_py(&self, value: T) -> PyResult<()> {
+            self.send(value)
+                .await
+                .map_err(|_| PyErr::receiver_channel_closed())
+        }
+    }
+
+    #[ext(pub, name = TokioMpscReceiverExt)]
+    impl<T> mpsc::Receiver<T> {
+        /// Receives the next value for this receiver.
+        async fn recv_py(&mut self) -> PyResult<T> {
+            self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
+        }
+
+        /// Receives at most `limit` values for this receiver and returns them.
+        ///
+        /// For `limit = 0`, an empty collection of messages will be returned immediately.
+        /// For `limit > 0`, if there are no messages in the channel's queue this method
+        /// will sleep until a message is sent.
+        async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
+            // get updates from receiver channel
+            let mut updates = Vec::with_capacity(limit);
+            let received = self.recv_many(&mut updates, limit).await;
+
+            // if we received zero items, then the channel was unexpectedly closed
+            if limit != 0 && received == 0 {
+                return Err(PyErr::receiver_channel_closed());
+            }
+
+            Ok(updates)
+        }
+
+        /// Tries to receive the next value for this receiver.
+        fn try_recv_py(&mut self) -> PyResult<Option<T>> {
+            match self.try_recv() {
+                Ok(v) => Ok(Some(v)),
+                Err(TryRecvError::Empty) => Ok(None),
+                Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
+            }
+        }
+    }
+}
+
+pub(crate) mod private {
+    use std::marker::Sized;
+
+    /// Sealed traits support
+    pub trait Sealed {}
+    impl<T: ?Sized> Sealed for T {}
+}
+
+/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
+#[repr(transparent)]
+pub(crate) struct ClonePy<T>(pub Py<T>);
+
+impl<T> Clone for ClonePy<T> {
+    fn clone(&self) -> Self {
+        Python::with_gil(|py| Self(self.0.clone_ref(py)))
+    }
+}
+
+/// A Python module implemented in Rust. The name of this function must match
+/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
+/// import the module.
+#[pymodule(name = "exo_pyo3_bindings")]
+fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    // install logger
+    pyo3_log::init();
+
+    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
+    //  work with maturin, where the types generate correctly, in the right folder, without
+    //  too many importing issues...
+    ident_submodule(m)?;
+    multiaddr_submodule(m)?;
+    networking_submodule(m)?;
+
+    // top-level constructs
+    // TODO: ...
+
+    Ok(())
+}
+
+define_stub_info_gatherer!(stub_info);
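The recv_many semantics above (limit = 0 returns immediately; otherwise sleep until at least one item arrives) are exactly what the stub methods connection_update_recv_many and gossipsub_recv_many expose. A hypothetical Python consumer that batches connection updates:

    import asyncio
    from exo_pyo3_bindings import ConnectionUpdateType, Keypair, NetworkingHandle

    async def watch_connections(handle: NetworkingHandle) -> None:
        while True:
            # returns up to 16 queued updates; sleeps only when the queue is empty
            updates = await handle.connection_update_recv_many(16)
            for u in updates:
                verb = "up" if u.update_type == ConnectionUpdateType.Connected else "down"
                print(f"{u.peer_id.to_base58()} {verb} at {u.remote_ipv4}:{u.remote_tcp_port}")

    asyncio.run(watch_connections(NetworkingHandle(Keypair.generate_ed25519())))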
+
+    Ok(())
+}
+
+define_stub_info_gatherer!(stub_info);
diff --git a/rust/exo_pyo3_bindings/src/networking.rs b/rust/exo_pyo3_bindings/src/networking.rs
new file mode 100644
index 00000000..021fc90e
--- /dev/null
+++ b/rust/exo_pyo3_bindings/src/networking.rs
@@ -0,0 +1,534 @@
+#![allow(
+    clippy::multiple_inherent_impl,
+    clippy::unnecessary_wraps,
+    clippy::unused_self,
+    clippy::needless_pass_by_value
+)]
+
+use crate::r#const::MPSC_CHANNEL_SIZE;
+use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
+use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
+use crate::pyclass;
+use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
+use libp2p::futures::StreamExt as _;
+use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
+use libp2p::swarm::SwarmEvent;
+use libp2p::{gossipsub, mdns};
+use pyo3::prelude::{PyModule, PyModuleMethods as _};
+use pyo3::types::PyBytes;
+use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
+use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
+use std::net::IpAddr;
+use tokio::sync::{Mutex, mpsc, oneshot};
+use networking::discovery;
+use networking::swarm::create_swarm;
+use util::ext::VecExt as _;
+
+mod exception {
+    use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments};
+    use pyo3::types::PyTuple;
+    use pyo3_stub_gen::{derive::*};
+
+    #[gen_stub_pyclass]
+    #[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
+    pub struct PyNoPeersSubscribedToTopicError {}
+
+    impl PyNoPeersSubscribedToTopicError {
+        const MSG: &'static str = "\
+            No peers are currently subscribed to receive messages on this topic. \
+            Wait for peers to subscribe or check your network connectivity.";
+
+        /// Creates a new [`PyErr`] of this type.
+        ///
+        /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
+        pub(crate) fn new_err() -> PyErr {
+            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
+        }
+    }
+
+    #[gen_stub_pymethods]
+    #[pymethods]
+    impl PyNoPeersSubscribedToTopicError {
+        #[new]
+        #[pyo3(signature = (*args))]
+        #[allow(unused_variables)]
+        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
+            Self {}
+        }
+
+        fn __repr__(&self) -> String {
+            format!("NoPeersSubscribedToTopicError(\"{}\")", Self::MSG)
+        }
+
+        fn __str__(&self) -> String {
+            Self::MSG.to_string()
+        }
+    }
+}
+
+/// Connection or disconnection event discriminant type.
+#[gen_stub_pyclass_enum]
+#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
+#[derive(Debug, Clone, PartialEq)]
+enum PyConnectionUpdateType {
+    Connected = 0,
+    Disconnected,
+}
+
+#[gen_stub_pyclass]
+#[pyclass(frozen, name = "ConnectionUpdate")]
+#[derive(Debug, Clone)]
+struct PyConnectionUpdate {
+    /// Whether this is a connection or disconnection event.
+    #[pyo3(get)]
+    update_type: PyConnectionUpdateType,
+
+    /// Identity of the peer that we have connected to or disconnected from.
+    #[pyo3(get)]
+    peer_id: PyPeerId,
+
+    /// Remote connection's IPv4 address.
+    #[pyo3(get)]
+    remote_ipv4: String,
+
+    /// Remote connection's TCP port.
+    #[pyo3(get)]
+    remote_tcp_port: u16,
+}
+
+enum ToTask {
+    GossipsubSubscribe {
+        topic: String,
+        result_tx: oneshot::Sender<PyResult<bool>>,
+    },
+    GossipsubUnsubscribe {
+        topic: String,
+        result_tx: oneshot::Sender<bool>,
+    },
+    GossipsubPublish {
+        topic: String,
+        data: Vec<u8>,
+        result_tx: oneshot::Sender<PyResult<MessageId>>,
+    },
+}
+
+#[allow(clippy::enum_glob_use)]
+async fn networking_task(
+    mut swarm: networking::swarm::Swarm,
+    mut to_task_rx: mpsc::Receiver<ToTask>,
+    connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
+    gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
+) {
+    use networking::swarm::BehaviourEvent::*;
+    use SwarmEvent::*;
+    use ToTask::*;
+    use mdns::Event::*;
+
+    log::info!("RUST: networking task started");
+
+    loop {
+        tokio::select! {
+            message = to_task_rx.recv() => {
+                // handle closed channel
+                let Some(message) = message else {
+                    log::info!("RUST: channel closed");
+                    break;
+                };
+
+                // dispatch incoming messages
+                match message {
+                    GossipsubSubscribe { topic, result_tx } => {
+                        // try to subscribe
+                        let result = swarm.behaviour_mut()
+                            .gossipsub.subscribe(&IdentTopic::new(topic));
+
+                        // send response oneshot
+                        if let Err(e) = result_tx.send(result.pyerr()) {
+                            log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
+                            continue;
+                        }
+                    }
+                    GossipsubUnsubscribe { topic, result_tx } => {
+                        // try to unsubscribe from the topic
+                        let result = swarm.behaviour_mut()
+                            .gossipsub.unsubscribe(&IdentTopic::new(topic));
+
+                        // send response oneshot (or exit if connection closed)
+                        if let Err(e) = result_tx.send(result) {
+                            log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
+                            continue;
+                        }
+                    }
+                    GossipsubPublish { topic, data, result_tx } => {
+                        // try to publish the data -> catch the NoPeersSubscribedToTopic error & convert it to the correct exception
+                        let result = swarm.behaviour_mut().gossipsub.publish(
+                            IdentTopic::new(topic), data);
+                        let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
+                            Err(exception::PyNoPeersSubscribedToTopicError::new_err())
+                        } else {
+                            result.pyerr()
+                        };
+
+                        // send response oneshot (or exit if connection closed)
+                        if let Err(e) = result_tx.send(pyresult) {
+                            log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
+                            continue;
+                        }
+                    }
+                }
+            }
+
+            // architectural solution to this problem:
+            // create a keep_alive behavior whose job it is to dial peers discovered by mDNS (and drop them when expired)
+            // -> it will emit TRUE connected/disconnected events consumable elsewhere
+            //
+            // gossipsub will feed off of dial attempts created by networking, and that will bootstrap its peer list;
+            // then for actual communication it will dial those peers if need be
+            swarm_event = swarm.select_next_some() => {
+                match swarm_event {
+                    Behaviour(Gossipsub(gossipsub::Event::Message {
+                        message: Message {
+                            topic,
+                            data,
+                            ..
+                        },
+                        ..
+                    })) => {
+                        // topic-ID is just the topic hash!!! (since we used identity hasher)
+                        let message = (topic.into_string(), data);
+
+                        // send incoming message to channel (or exit if connection closed)
+                        if let Err(e) = gossipsub_message_tx.send(message).await {
+                            log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
+                            continue;
+                        }
+                    },
+                    Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, ..
})) => { + // grab IPv4 string + let remote_ipv4 = match remote_ip { + IpAddr::V4(ip) => ip.to_string(), + IpAddr::V6(ip) => { + log::warn!("RUST: ignoring connection to IPv6 address: {ip}"); + continue; + } + }; + + // send connection event to channel (or exit if connection closed) + if let Err(e) = connection_update_tx.send(PyConnectionUpdate { + update_type: PyConnectionUpdateType::Connected, + peer_id: PyPeerId(peer_id), + remote_ipv4, + remote_tcp_port, + }).await { + log::error!("RUST: could not send connection update since channel already closed: {e}"); + continue; + } + }, + Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => { + // grab IPv4 string + let remote_ipv4 = match remote_ip { + IpAddr::V4(ip) => ip.to_string(), + IpAddr::V6(ip) => { + log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}"); + continue; + } + }; + + // send disconnection event to channel (or exit if connection closed) + if let Err(e) = connection_update_tx.send(PyConnectionUpdate { + update_type: PyConnectionUpdateType::Disconnected, + peer_id: PyPeerId(peer_id), + remote_ipv4, + remote_tcp_port, + }).await { + log::error!("RUST: could not send connection update since channel already closed: {e}"); + continue; + } + }, + e => { + log::info!("RUST: other event {e:?}"); + } + } + } + } + } + + log::info!("RUST: networking task stopped"); +} + +#[gen_stub_pyclass] +#[pyclass(name = "NetworkingHandle")] +#[derive(Debug)] +struct PyNetworkingHandle { + // channels + to_task_tx: Option>, + connection_update_rx: Mutex>, + gossipsub_message_rx: Mutex)>>, +} + +impl Drop for PyNetworkingHandle { + fn drop(&mut self) { + // TODO: may or may not need to await a "kill-signal" oneshot channel message, + // to ensure that the networking task is done BEFORE exiting the clear function... + // but this may require GIL?? and it may not be safe to call GIL here?? + self.to_task_tx = None; // Using Option as a trick to force channel to be dropped + } +} + +#[allow(clippy::expect_used)] +impl PyNetworkingHandle { + fn new( + to_task_tx: mpsc::Sender, + connection_update_rx: mpsc::Receiver, + gossipsub_message_rx: mpsc::Receiver<(String, Vec)>, + ) -> Self { + Self { + to_task_tx: Some(to_task_tx), + connection_update_rx: Mutex::new(connection_update_rx), + gossipsub_message_rx: Mutex::new(gossipsub_message_rx), + } + } + + const fn to_task_tx(&self) -> &mpsc::Sender { + self.to_task_tx + .as_ref() + .expect("The sender should only be None after de-initialization.") + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyNetworkingHandle { + // NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()` + // immediately beforehand to release the interpreter. + // SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await + + // ---- Lifecycle management methods ---- + + #[new] + fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult { + use pyo3_async_runtimes::tokio::get_runtime; + + // create communication channels + let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE); + let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE); + let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE); + + // get identity + let identity = identity.borrow().0.clone(); + + // create networking swarm (within tokio context!! 
or it crashes) + let swarm = get_runtime() + .block_on(async { create_swarm(identity) }) + .pyerr()?; + + // spawn tokio task running the networking logic + get_runtime().spawn(async move { + networking_task( + swarm, + to_task_rx, + connection_update_tx, + gossipsub_message_tx, + ) + .await; + }); + Ok(Self::new( + to_task_tx, + connection_update_rx, + gossipsub_message_rx, + )) + } + + #[gen_stub(skip)] + const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + Ok(()) // This is needed purely so `__clear__` can work + } + + #[gen_stub(skip)] + fn __clear__(&mut self) { + // TODO: may or may not need to await a "kill-signal" oneshot channel message, + // to ensure that the networking task is done BEFORE exiting the clear function... + // but this may require GIL?? and it may not be safe to call GIL here?? + self.to_task_tx = None; // Using Option as a trick to force channel to be dropped + } + + // ---- Connection update receiver methods ---- + + /// Receives the next `ConnectionUpdate` from networking. + async fn connection_update_recv(&self) -> PyResult { + self.connection_update_rx + .lock() + .allow_threads_py() // allow-threads-aware async call + .await + .recv_py() + .allow_threads_py() // allow-threads-aware async call + .await + } + + /// Receives at most `limit` `ConnectionUpdate`s from networking and returns them. + /// + /// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately. + /// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method + /// will sleep until a `ConnectionUpdate`s is sent. + async fn connection_update_recv_many(&self, limit: usize) -> PyResult> { + self.connection_update_rx + .lock() + .allow_threads_py() // allow-threads-aware async call + .await + .recv_many_py(limit) + .allow_threads_py() // allow-threads-aware async call + .await + } + + // TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex) + // so its too dangerous to expose just yet. figure out a better semantics for handling this, + // so things don't randomly block + // /// Tries to receive the next `ConnectionUpdate` from networking. + // fn connection_update_try_recv(&self) -> PyResult> { + // self.connection_update_rx.blocking_lock().try_recv_py() + // } + // + // /// Checks if the `ConnectionUpdate` channel is empty. + // fn connection_update_is_empty(&self) -> bool { + // self.connection_update_rx.blocking_lock().is_empty() + // } + // + // /// Returns the number of `ConnectionUpdate`s in the channel. + // fn connection_update_len(&self) -> usize { + // self.connection_update_rx.blocking_lock().len() + // } + + // ---- Gossipsub management methods ---- + + /// Subscribe to a `GossipSub` topic. + /// + /// Returns `True` if the subscription worked. Returns `False` if we were already subscribed. + async fn gossipsub_subscribe(&self, topic: String) -> PyResult { + let (tx, rx) = oneshot::channel(); + + // send off request to subscribe + self.to_task_tx() + .send_py(ToTask::GossipsubSubscribe { + topic, + result_tx: tx, + }) + .allow_threads_py() // allow-threads-aware async call + .await?; + + // wait for response & return any errors + rx.allow_threads_py() // allow-threads-aware async call + .await + .map_err(|_| PyErr::receiver_channel_closed())? + } + + /// Unsubscribes from a `GossipSub` topic. + /// + /// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed. 
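+    // All of the gossipsub methods below follow the same actor-style request/response
+    // shape: a command plus a oneshot sender goes onto `to_task_tx`, and the networking
+    // task replies through the oneshot. Roughly (a sketch, with error handling elided):
+    //
+    //     let (tx, rx) = oneshot::channel();
+    //     self.to_task_tx().send_py(ToTask::GossipsubSubscribe { topic, result_tx: tx }).await?;
+    //     let subscribed: bool = rx.await.map_err(|_| PyErr::receiver_channel_closed())??;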
+    async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
+        let (tx, rx) = oneshot::channel();
+
+        // send off request to unsubscribe
+        self.to_task_tx()
+            .send_py(ToTask::GossipsubUnsubscribe {
+                topic,
+                result_tx: tx,
+            })
+            .allow_threads_py() // allow-threads-aware async call
+            .await?;
+
+        // wait for response & convert any errors
+        rx.allow_threads_py() // allow-threads-aware async call
+            .await
+            .map_err(|_| PyErr::receiver_channel_closed())
+    }
+
+    /// Publishes a message to a `GossipSub` topic.
+    ///
+    /// If no peers are found that subscribe to this topic, throws a `NoPeersSubscribedToTopicError` exception.
+    async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
+        let (tx, rx) = oneshot::channel();
+
+        // send off request to publish
+        let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
+        self.to_task_tx()
+            .send_py(ToTask::GossipsubPublish {
+                topic,
+                data,
+                result_tx: tx,
+            })
+            .allow_threads_py() // allow-threads-aware async call
+            .await?;
+
+        // wait for response & return any errors => ignore the MessageId for now!!!
+        let _ = rx
+            .allow_threads_py() // allow-threads-aware async call
+            .await
+            .map_err(|_| PyErr::receiver_channel_closed())??;
+        Ok(())
+    }
+
+    // ---- Gossipsub message receiver methods ----
+
+    /// Receives the next message from the `GossipSub` network.
+    async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
+        self.gossipsub_message_rx
+            .lock()
+            .allow_threads_py() // allow-threads-aware async call
+            .await
+            .recv_py()
+            .allow_threads_py() // allow-threads-aware async call
+            .await
+            .map(|(t, d)| (t, d.pybytes()))
+    }
+
+    /// Receives at most `limit` messages from the `GossipSub` network and returns them.
+    ///
+    /// For `limit = 0`, an empty collection of messages will be returned immediately.
+    /// For `limit > 0`, if there are no messages in the channel's queue this method
+    /// will sleep until a message is sent.
+    async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
+        Ok(self
+            .gossipsub_message_rx
+            .lock()
+            .allow_threads_py() // allow-threads-aware async call
+            .await
+            .recv_many_py(limit)
+            .allow_threads_py() // allow-threads-aware async call
+            .await?
+            .map(|(t, d)| (t, d.pybytes()))) // (`map` on Vec comes from `util::ext::VecExt`)
+    }
+
+    // TODO: rn this blocks the main thread if anything else is awaiting the channel (bc it's a mutex)
+    //  so it's too dangerous to expose just yet. figure out better semantics for handling this,
+    //  so things don't randomly block
+    // /// Tries to receive the next message from the `GossipSub` network.
+    // fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
+    //     Ok(self
+    //         .gossipsub_message_rx
+    //         .blocking_lock()
+    //         .try_recv_py()?
+    //         .map(|(t, d)| (t, d.pybytes())))
+    // }
+    //
+    // /// Checks if the `GossipSub` message channel is empty.
+    // fn gossipsub_is_empty(&self) -> bool {
+    //     self.gossipsub_message_rx.blocking_lock().is_empty()
+    // }
+    //
+    // /// Returns the number of `GossipSub` messages in the channel.
+ // fn gossipsub_len(&self) -> usize { + // self.gossipsub_message_rx.blocking_lock().len() + // } +} + +pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs b/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs new file mode 100644 index 00000000..3c27526a --- /dev/null +++ b/rust/exo_pyo3_bindings/src/pylibp2p/ident.rs @@ -0,0 +1,159 @@ +use crate::ext::ResultExt as _; +use libp2p::PeerId; +use libp2p::identity::Keypair; +use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _}; +use pyo3::types::PyBytes; +use pyo3::{Bound, PyResult, Python, pyclass, pymethods}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +/// Identity keypair of a node. +#[gen_stub_pyclass] +#[pyclass(name = "Keypair", frozen)] +#[repr(transparent)] +pub struct PyKeypair(pub Keypair); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyKeypair { + /// Generate a new Ed25519 keypair. + #[staticmethod] + fn generate_ed25519() -> Self { + Self(Keypair::generate_ed25519()) + } + + /// Generate a new ECDSA keypair. + #[staticmethod] + fn generate_ecdsa() -> Self { + Self(Keypair::generate_ecdsa()) + } + + /// Generate a new Secp256k1 keypair. + #[staticmethod] + fn generate_secp256k1() -> Self { + Self(Keypair::generate_secp256k1()) + } + + /// Decode a private key from a protobuf structure and parse it as a `Keypair`. + #[staticmethod] + fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?)) + } + + /// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo` + /// format (i.e. unencrypted) as defined in [RFC5208]. + /// + /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5 + #[staticmethod] + fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?)) + } + + /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey` + /// structure as defined in [RFC5915]. + /// + /// [RFC5915]: https://tools.ietf.org/html/rfc5915 + #[staticmethod] + fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?)) + } + + #[staticmethod] + fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let mut bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?)) + } + + /// Encode a private key as protobuf structure. + fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult> { + let bytes = self.0.to_protobuf_encoding().pyerr()?; + Ok(PyBytes::new(py, &bytes)) + } + + /// Convert the `Keypair` into the corresponding `PeerId`. + fn to_peer_id(&self) -> PyPeerId { + PyPeerId(self.0.public().to_peer_id()) + } + + // /// Hidden constructor for pickling support. TODO: figure out how to do pickling... 
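+    // (For reference, a pickling roundtrip would lean on the protobuf encoding that
+    //  already exists above - sketch only, using libp2p's `Keypair` API directly:)
+    //
+    //     let kp = libp2p::identity::Keypair::generate_ed25519();
+    //     let bytes = kp.to_protobuf_encoding()?;
+    //     let restored = libp2p::identity::Keypair::from_protobuf_encoding(&bytes)?;
+    //     assert_eq!(kp.public(), restored.public());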
+ // #[gen_stub(skip)] + // #[new] + // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult { + // Self::from_protobuf_encoding(bytes) + // } + // + // #[gen_stub(skip)] + // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> { + // *self = Self::from_protobuf_encoding(state)?; + // Ok(()) + // } + // + // #[gen_stub(skip)] + // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult> { + // self.to_protobuf_encoding(py) + // } + // + // #[gen_stub(skip)] + // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> { + // Ok((self.to_protobuf_encoding(py)?,)) + // } +} + +/// Identifier of a peer of the network. +/// +/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer +/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md). +#[gen_stub_pyclass] +#[pyclass(name = "PeerId", frozen)] +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyPeerId(pub PeerId); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyPeerId { + /// Generates a random peer ID from a cryptographically secure PRNG. + /// + /// This is useful for randomly walking on a DHT, or for testing purposes. + #[staticmethod] + fn random() -> Self { + Self(PeerId::random()) + } + + /// Parses a `PeerId` from bytes. + #[staticmethod] + fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(PeerId::from_bytes(&bytes).pyerr()?)) + } + + /// Returns a raw bytes representation of this `PeerId`. + fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + let bytes = self.0.to_bytes(); + PyBytes::new(py, &bytes) + } + + /// Returns a base-58 encoded string of this `PeerId`. + fn to_base58(&self) -> String { + self.0.to_base58() + } + + fn __repr__(&self) -> String { + format!("PeerId({})", self.to_base58()) + } + + fn __str__(&self) -> String { + self.to_base58() + } +} + +pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs b/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs new file mode 100644 index 00000000..8eb1bdc0 --- /dev/null +++ b/rust/exo_pyo3_bindings/src/pylibp2p/mod.rs @@ -0,0 +1,8 @@ +//! A module for exposing Rust's libp2p datatypes over Pyo3 +//! +//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own +//! independent identity type of some kind or another. This may require handshaking. +//! + +pub mod ident; +pub mod multiaddr; diff --git a/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs b/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs new file mode 100644 index 00000000..4d398b53 --- /dev/null +++ b/rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs @@ -0,0 +1,81 @@ +use crate::ext::ResultExt as _; +use libp2p::Multiaddr; +use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _}; +use pyo3::types::PyBytes; +use pyo3::{Bound, PyResult, Python, pyclass, pymethods}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use std::str::FromStr as _; + +/// Representation of a Multiaddr. +#[gen_stub_pyclass] +#[pyclass(name = "Multiaddr", frozen)] +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyMultiaddr(pub Multiaddr); + +#[gen_stub_pymethods] +#[pymethods] +#[allow(clippy::needless_pass_by_value)] +impl PyMultiaddr { + /// Create a new, empty multiaddress. 
+ #[staticmethod] + fn empty() -> Self { + Self(Multiaddr::empty()) + } + + /// Create a new, empty multiaddress with the given capacity. + #[staticmethod] + fn with_capacity(n: usize) -> Self { + Self(Multiaddr::with_capacity(n)) + } + + /// Parse a `Multiaddr` value from its byte slice representation. + #[staticmethod] + fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult { + let bytes = Vec::from(bytes.as_bytes()); + Ok(Self(Multiaddr::try_from(bytes).pyerr()?)) + } + + /// Parse a `Multiaddr` value from its string representation. + #[staticmethod] + fn from_string(string: String) -> PyResult { + Ok(Self(Multiaddr::from_str(&string).pyerr()?)) + } + + /// Return the length in bytes of this multiaddress. + fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if the length of this multiaddress is 0. + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Return a copy of this [`Multiaddr`]'s byte representation. + fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + let bytes = self.0.to_vec(); + PyBytes::new(py, &bytes) + } + + /// Convert a Multiaddr to a string. + fn to_string(&self) -> String { + self.0.to_string() + } + + #[gen_stub(skip)] + fn __repr__(&self) -> String { + format!("Multiaddr({})", self.0) + } + + #[gen_stub(skip)] + fn __str__(&self) -> String { + self.to_string() + } +} + +pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + Ok(()) +} diff --git a/rust/exo_pyo3_bindings/tests/dummy.rs b/rust/exo_pyo3_bindings/tests/dummy.rs new file mode 100644 index 00000000..7d1ce0e4 --- /dev/null +++ b/rust/exo_pyo3_bindings/tests/dummy.rs @@ -0,0 +1,54 @@ +#[cfg(test)] +mod tests { + use core::mem::drop; + use core::option::Option::Some; + use core::time::Duration; + use tokio; + use tokio::sync::mpsc; + + #[tokio::test] + async fn test_drop_channel() { + struct Ping; + + let (tx, mut rx) = mpsc::channel::(10); + + let _ = tokio::spawn(async move { + println!("TASK: entered"); + + loop { + tokio::select! 
{ + result = rx.recv() => { + match result { + Some(_) => { + println!("TASK: pinged"); + } + None => { + println!("TASK: closing channel"); + break; + } + } + } + _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => { + println!("TASK: heartbeat"); + } + } + } + + println!("TASK: exited"); + }); + + let tx2 = tx.clone(); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + + tx.send(Ping).await.expect("Should not fail"); + drop(tx); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + + tx2.send(Ping).await.expect("Should not fail"); + drop(tx2); + + tokio::time::sleep(Duration::from_secs_f32(0.11)).await; + } +} diff --git a/rust/exo_pyo3_bindings/tests/test_python.py b/rust/exo_pyo3_bindings/tests/test_python.py new file mode 100644 index 00000000..ce5a676f --- /dev/null +++ b/rust/exo_pyo3_bindings/tests/test_python.py @@ -0,0 +1,34 @@ +import asyncio + +import pytest +from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError + + +@pytest.mark.asyncio +async def test_sleep_on_multiple_items() -> None: + print("PYTHON: starting handle") + h = NetworkingHandle(Keypair.generate_ed25519()) + + ct = asyncio.create_task(_await_cons(h)) + mt = asyncio.create_task(_await_msg(h)) + + # sleep for 4 ticks + for i in range(4): + await asyncio.sleep(1) + + try: + await h.gossipsub_publish("topic", b"somehting or other") + except NoPeersSubscribedToTopicError as e: + print("caught it", e) + + +async def _await_cons(h: NetworkingHandle): + while True: + c = await h.connection_update_recv() + print(f"PYTHON: connection update: {c}") + + +async def _await_msg(h: NetworkingHandle): + while True: + m = await h.gossipsub_recv() + print(f"PYTHON: message: {m}") diff --git a/rust/networking/Cargo.toml b/rust/networking/Cargo.toml new file mode 100644 index 00000000..47d61f41 --- /dev/null +++ b/rust/networking/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "networking" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[lib] +doctest = false +name = "networking" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +# datastructures +either = { workspace = true } + +# macro dependencies +extend = { workspace = true } +delegate = { workspace = true } +impl-trait-for-tuples = { workspace = true } +derive_more = { workspace = true } + +# async +tokio = { workspace = true, features = ["full"] } +futures = { workspace = true } +futures-timer = { workspace = true } + +# utility dependencies +util = { workspace = true } +thiserror = { workspace = true } +#internment = { workspace = true } +#recursion = { workspace = true } +#generativity = { workspace = true } +#itertools = { workspace = true } +tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] } +keccak-const = { workspace = true } + +# tracing/logging +log = { workspace = true } + +# networking +libp2p = { workspace = true, features = ["full"] } \ No newline at end of file diff --git a/rust/networking/examples/chatroom.rs b/rust/networking/examples/chatroom.rs new file mode 100644 index 00000000..3371b46d --- /dev/null +++ b/rust/networking/examples/chatroom.rs @@ -0,0 +1,74 @@ +use futures::stream::StreamExt as _; +use libp2p::{gossipsub, identity, swarm::SwarmEvent}; +use networking::{discovery, swarm}; +use tokio::{io, io::AsyncBufReadExt as _, select}; +use tracing_subscriber::EnvFilter; +use tracing_subscriber::filter::LevelFilter; + +#[tokio::main] +async fn main() { + let _ = tracing_subscriber::fmt() + 
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into())) + .try_init(); + + // Configure swarm + let mut swarm = + swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed"); + + // Create a Gossipsub topic & subscribe + let topic = gossipsub::IdentTopic::new("test-net"); + swarm + .behaviour_mut() + .gossipsub + .subscribe(&topic) + .expect("Subscribing to topic failed"); + + // Read full lines from stdin + let mut stdin = io::BufReader::new(io::stdin()).lines(); + println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub"); + + // Kick it off + loop { + select! { + // on gossipsub outgoing + Ok(Some(line)) = stdin.next_line() => { + if let Err(e) = swarm + .behaviour_mut().gossipsub + .publish(topic.clone(), line.as_bytes()) { + println!("Publish error: {e:?}"); + } + } + event = swarm.select_next_some() => match event { + // on gossipsub incoming + SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message { + propagation_source: peer_id, + message_id: id, + message, + })) => println!( + "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n", + String::from_utf8_lossy(&message.data), + ), + + // on discovery + SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e { + discovery::Event::ConnectionEstablished { + peer_id, connection_id, remote_ip, remote_tcp_port + } => { + println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n"); + } + discovery::Event::ConnectionClosed { + peer_id, connection_id, remote_ip, remote_tcp_port + } => { + eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n"); + } + } + + // ignore outgoing errors: those are normal + e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); } + + // otherwise log any other event + e => { log::info!("Other event {e:?}"); } + } + } + } +} diff --git a/rust/networking/examples/chatroom_manual.rs b/rust/networking/examples/chatroom_manual.rs new file mode 100644 index 00000000..6c1ffd88 --- /dev/null +++ b/rust/networking/examples/chatroom_manual.rs @@ -0,0 +1,130 @@ +// Copyright 2018 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. 
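+//
+// NOTE: this file is a lightly-adapted copy of the upstream rust-libp2p chat example,
+// kept for comparison against examples/chatroom.rs above (which goes through this
+// crate's own swarm/discovery modules). Either can be run the usual Cargo way, e.g.:
+//
+//     RUST_LOG=info cargo run --example chatroom_manual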
+ +use std::{ + error::Error, + hash::{Hash}, +}; +use std::time::Duration; +use futures::stream::StreamExt; +use libp2p::{ + gossipsub, mdns, noise, + swarm::{NetworkBehaviour, SwarmEvent}, + tcp, yamux, +}; +use tokio::{io, io::AsyncBufReadExt, select}; +use tracing_subscriber::EnvFilter; + +// We create a custom network behaviour that combines Gossipsub and Mdns. +#[derive(NetworkBehaviour)] +struct MyBehaviour { + gossipsub: gossipsub::Behaviour, + mdns: mdns::tokio::Behaviour, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init(); + + let mut swarm = libp2p::SwarmBuilder::with_new_identity() + .with_tokio() + .with_tcp( + tcp::Config::default(), + noise::Config::new, + yamux::Config::default, + )? + .with_behaviour(|key| { + // Set a custom gossipsub configuration + let gossipsub_config = gossipsub::ConfigBuilder::default() + .heartbeat_interval(Duration::from_secs(10)) + .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing) + .build() + .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`. + + // build a gossipsub network behaviour + let gossipsub = gossipsub::Behaviour::new( + gossipsub::MessageAuthenticity::Signed(key.clone()), + gossipsub_config, + )?; + + let mdns = + mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?; + Ok(MyBehaviour { gossipsub, mdns }) + })? + .build(); + + println!("Running swarm with identity {}", swarm.local_peer_id()); + + // Create a Gossipsub topic + let topic = gossipsub::IdentTopic::new("test-net"); + // subscribes to our topic + swarm.behaviour_mut().gossipsub.subscribe(&topic)?; + + // Read full lines from stdin + let mut stdin = io::BufReader::new(io::stdin()).lines(); + + // Listen on all interfaces and whatever port the OS assigns + swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; + + println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub"); + + // Kick it off + loop { + select! { + Ok(Some(line)) = stdin.next_line() => { + if let Err(e) = swarm + .behaviour_mut().gossipsub + .publish(topic.clone(), line.as_bytes()) { + println!("Publish error: {e:?}"); + } + } + event = swarm.select_next_some() => match event { + SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => { + for (peer_id, multiaddr) in list { + println!("mDNS discovered a new peer: {peer_id} on {multiaddr}"); + swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id); + } + }, + SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => { + for (peer_id, multiaddr) in list { + println!("mDNS discover peer has expired: {peer_id} on {multiaddr}"); + swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id); + } + }, + SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message { + propagation_source: peer_id, + message_id: id, + message, + })) => println!( + "Got message: '{}' with id: {id} from peer: {peer_id}", + String::from_utf8_lossy(&message.data), + ), + SwarmEvent::NewListenAddr { address, .. 
} => { + println!("Local node is listening on {address}"); + } + e => { + println!("Other swarm event: {:?}", e); + } + } + } + } +} \ No newline at end of file diff --git a/rust/networking/src/RESEARCH_NOTES.txt b/rust/networking/src/RESEARCH_NOTES.txt new file mode 100644 index 00000000..2beeca57 --- /dev/null +++ b/rust/networking/src/RESEARCH_NOTES.txt @@ -0,0 +1,44 @@ +https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609 +^raw sockets distributed -> `` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html +--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets. +--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received. +--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools. +--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html +----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET + +https://stackoverflow.com/questions/17169298/af-packet-on-osx +^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only +^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing + +https://www.unix.com/man_page/mojave/4/ip/ +^OS X manpages for IP + +https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts +^driver kit, system extensions & kexts for macOS + +---- + +To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network. +--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a +----> here is a guide on how to set up thunderbolt-ethernet on linux +----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS + +https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7 +^GPT discussion about making socket interface + +https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,?? 
+https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2 +^GPT discussion about accessing TB on MacOS low level interactions + +-------------------------------- + +https://www.intel.com/content/www/us/en/support/articles/000098893/software.html +^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge + + +--------------------------------- + +https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/ +-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet?? +-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!! +-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c \ No newline at end of file diff --git a/rust/networking/src/discovery.rs b/rust/networking/src/discovery.rs new file mode 100644 index 00000000..64a297c3 --- /dev/null +++ b/rust/networking/src/discovery.rs @@ -0,0 +1,379 @@ +use crate::keep_alive; +use delegate::delegate; +use either::Either; +use futures::FutureExt; +use futures_timer::Delay; +use libp2p::core::transport::PortUse; +use libp2p::core::{ConnectedPoint, Endpoint}; +use libp2p::swarm::behaviour::ConnectionEstablished; +use libp2p::swarm::dial_opts::DialOpts; +use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm}; +use libp2p::{Multiaddr, PeerId, identity, mdns}; +use std::collections::{BTreeSet, HashMap}; +use std::convert::Infallible; +use std::io; +use std::net::IpAddr; +use std::task::{Context, Poll}; +use std::time::Duration; +use util::wakerdeque::WakerDeque; +use crate::ext::MultiaddrExt; + + +const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5); + +mod managed { + use std::io; + use std::time::Duration; + use libp2p::{identity, mdns, ping}; + use libp2p::swarm::NetworkBehaviour; + + const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500); + const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500); + const PING_TIMEOUT: Duration = Duration::from_millis(2_500); + const PING_INTERVAL: Duration = Duration::from_millis(2_500); + + #[derive(NetworkBehaviour)] + pub struct Behaviour { + mdns: mdns::tokio::Behaviour, + ping: ping::Behaviour, + } + + impl Behaviour { + pub fn new(keypair: &identity::Keypair) -> io::Result { + Ok(Self { + mdns: mdns_behaviour(keypair)?, + ping: ping_behaviour(), + }) + } + } + + fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result { + use mdns::{Config, tokio}; + + // mDNS config => enable IPv6 + let mdns_config = Config { + ttl: MDNS_RECORD_TTL, + query_interval: MDNS_QUERY_INTERVAL, + + // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work + ..Default::default() + }; + + let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id()); + Ok(mdns_behaviour?) + } + + fn ping_behaviour() -> ping::Behaviour { + ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL)) + } +} + +/// Events for when a listening connection is truly established and truly closed. 
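+// A consumption sketch (assuming a swarm built via `swarm::create_swarm`, as in
+// examples/chatroom.rs; the event and field names below are the real ones from this module):
+//
+//     while let Some(event) = swarm.next().await {
+//         if let SwarmEvent::Behaviour(BehaviourEvent::Discovery(e)) = event {
+//             match e {
+//                 Event::ConnectionEstablished { peer_id, remote_ip, .. } => log::info!("up: {peer_id} @ {remote_ip}"),
+//                 Event::ConnectionClosed { peer_id, .. } => log::info!("down: {peer_id}"),
+//             }
+//         }
+//     }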
+#[derive(Debug, Clone)]
+pub enum Event {
+    ConnectionEstablished {
+        peer_id: PeerId,
+        connection_id: ConnectionId,
+        remote_ip: IpAddr,
+        remote_tcp_port: u16,
+    },
+    ConnectionClosed {
+        peer_id: PeerId,
+        connection_id: ConnectionId,
+        remote_ip: IpAddr,
+        remote_tcp_port: u16,
+    },
+}
+
+/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
+///
+/// The behaviour operates as such:
+/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
+///    to the swarm.
+/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
+///    immediately, and expired but connected peers are disconnected from immediately.
+/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
+///    connected peers are disconnected from.
+pub struct Behaviour {
+    // state-tracking for managed behaviors & mDNS-discovered peers
+    managed: managed::Behaviour,
+    mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
+
+    retry_delay: Delay, // retry interval
+
+    // pending events to emit => waker-backed deque to control polling
+    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
+}
+
+impl Behaviour {
+    pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
+        Ok(Self {
+            managed: managed::Behaviour::new(keypair)?,
+            mdns_discovered: HashMap::new(),
+            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
+            pending_events: WakerDeque::new(),
+        })
+    }
+
+    fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
+        self.pending_events.push_back(ToSwarm::Dial {
+            opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
+        })
+    }
+
+    fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
+        // push front to make this IMMEDIATE
+        self.pending_events.push_front(ToSwarm::CloseConnection {
+            peer_id,
+            connection: CloseConnection::One(connection),
+        })
+    }
+
+    fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
+        for (p, ma) in peers {
+            self.dial(p, ma.clone()); // always connect
+
+            // get the peer's multi-addresses, or insert them if missing
+            let Some(mas) = self.mdns_discovered.get_mut(&p) else {
+                self.mdns_discovered.insert(p, BTreeSet::from([ma]));
+                continue;
+            };
+
+            // the multiaddress should never already be present - else something has gone wrong
+            let is_new_addr = mas.insert(ma);
+            assert!(is_new_addr, "cannot discover a discovered peer");
+        }
+    }
+
+    fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
+        for (p, ma) in peers {
+            // at this point, we *must* have the peer
+            let mas = self
+                .mdns_discovered
+                .get_mut(&p)
+                .expect("nonexistent peer cannot expire");
+
+            // at this point, we *must* have the multiaddress
+            let was_present = mas.remove(&ma);
+            assert!(was_present, "nonexistent multiaddress cannot expire");
+
+            // if empty, remove the peer-id entirely
+            if mas.is_empty() {
+                self.mdns_discovered.remove(&p);
+            }
+        }
+    }
+
+    fn on_connection_established(
+        &mut self,
+        peer_id: PeerId,
+        connection_id: ConnectionId,
+        remote_ip: IpAddr,
+        remote_tcp_port: u16,
+    ) {
+        // send out connected event
+        self.pending_events
+            .push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
+                peer_id,
+                connection_id,
+                remote_ip,
+                remote_tcp_port,
+            }));
+    }
+
+    fn on_connection_closed(
+        &mut self,
+        peer_id: PeerId,
+        connection_id: ConnectionId,
+        remote_ip: IpAddr,
+        remote_tcp_port: u16,
+    ) {
+        // send out disconnected event
+        self.pending_events
+            .push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
+                peer_id,
+                connection_id,
+                remote_ip,
+                remote_tcp_port,
+ })); + } +} + +impl NetworkBehaviour for Behaviour { + type ConnectionHandler = + ConnectionHandlerSelect>; + type ToSwarm = Event; + + // simply delegate to underlying mDNS behaviour + + delegate! { + to self.managed { + fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>; + fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option, addresses: &[Multiaddr], effective_role: Endpoint) -> Result, ConnectionDenied>; + } + } + + fn handle_established_inbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + local_addr: &Multiaddr, + remote_addr: &Multiaddr, + ) -> Result, ConnectionDenied> { + Ok(ConnectionHandler::select( + dummy::ConnectionHandler, + self.managed.handle_established_inbound_connection( + connection_id, + peer, + local_addr, + remote_addr, + )?, + )) + } + + #[allow(clippy::needless_question_mark)] + fn handle_established_outbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + addr: &Multiaddr, + role_override: Endpoint, + port_use: PortUse, + ) -> Result, ConnectionDenied> { + Ok(ConnectionHandler::select( + dummy::ConnectionHandler, + self.managed.handle_established_outbound_connection( + connection_id, + peer, + addr, + role_override, + port_use, + )?, + )) + } + + fn on_connection_handler_event( + &mut self, + peer_id: PeerId, + connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + match event { + Either::Left(ev) => libp2p::core::util::unreachable(ev), + Either::Right(ev) => self.managed.on_connection_handler_event( + peer_id, + connection_id, + ev, + ), + } + } + + // hook into these methods to drive behavior + + fn on_swarm_event(&mut self, event: FromSwarm) { + self.managed.on_swarm_event(event); // let mDNS handle swarm events + + // handle swarm events to update internal state: + match event { + FromSwarm::ConnectionEstablished(ConnectionEstablished { + peer_id, + connection_id, + endpoint, + .. + }) => { + let remote_address = match endpoint { + ConnectedPoint::Dialer { address, .. } => address, + ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr, + }; + + if let Some((ip, port)) = remote_address.try_to_tcp_addr() { + // handle connection established event which is filtered correctly + self.on_connection_established(peer_id, connection_id, ip, port) + } + } + FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id, + endpoint, + .. + }) => { + let remote_address = match endpoint { + ConnectedPoint::Dialer { address, .. } => address, + ConnectedPoint::Listener { send_back_addr, .. 
} => send_back_addr, + }; + + if let Some((ip, port)) = remote_address.try_to_tcp_addr() { + // handle connection closed event which is filtered correctly + self.on_connection_closed(peer_id, connection_id, ip, port) + } + } + + // since we are running TCP/IP transport layer, we are assuming that + // no address changes can occur, hence encountering one is a fatal error + FromSwarm::AddressChange(a) => { + unreachable!("unhandlable: address change encountered: {:?}", a) + } + _ => {} + } + } + + fn poll(&mut self, cx: &mut Context) -> Poll>> { + // delegate to managed behaviors for any behaviors they need to perform + match self.managed.poll(cx) { + Poll::Ready(ToSwarm::GenerateEvent(e)) => { + match e { + // handle discovered and expired events from mDNS + managed::BehaviourEvent::Mdns(e) => match e.clone() { + mdns::Event::Discovered(peers) => { + self.handle_mdns_discovered(peers); + } + mdns::Event::Expired(peers) => { + self.handle_mdns_expired(peers); + } + } + + // handle ping events => if error then disconnect + managed::BehaviourEvent::Ping(e) => { + if let Err(_) = e.result { + self.close_connection(e.peer, e.connection.clone()) + } + } + } + + // since we just consumed an event, we should immediately wake just in case + // there are more events to come where that came from + cx.waker().wake_by_ref(); + } + + + // forward any other mDNS event to the swarm or its connection handler(s) + Poll::Ready(e) => { + return Poll::Ready( + e.map_out(|_| unreachable!("events returning to swarm already handled")) + .map_in(Either::Right), + ); + } + + Poll::Pending => {} + } + + // retry connecting to all mDNS peers periodically (fails safely if already connected) + if self.retry_delay.poll_unpin(cx).is_ready() { + for (p, mas) in self.mdns_discovered.clone() { + for ma in mas { + self.dial(p, ma) + } + } + self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout + } + + // send out any pending events from our own service + if let Some(e) = self.pending_events.pop_front(cx) { + return Poll::Ready(e.map_in(Either::Left)); + } + + // wait for pending events + Poll::Pending + } +} diff --git a/rust/networking/src/keep_alive.rs b/rust/networking/src/keep_alive.rs new file mode 100644 index 00000000..eb67aecb --- /dev/null +++ b/rust/networking/src/keep_alive.rs @@ -0,0 +1,44 @@ +use delegate::delegate; +use libp2p::swarm::handler::ConnectionEvent; +use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler}; +use std::task::{Context, Poll}; + +/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps +/// the connection alive. +#[derive(Clone)] +#[repr(transparent)] +pub struct ConnectionHandler(dummy::ConnectionHandler); + +impl ConnectionHandler { + pub fn new() -> Self { + ConnectionHandler(dummy::ConnectionHandler) + } +} + +impl handler::ConnectionHandler for ConnectionHandler { + // delegate types and implementation mostly to dummy handler + type FromBehaviour = ::FromBehaviour; + type ToBehaviour = ::ToBehaviour; + type InboundProtocol = + ::InboundProtocol; + type OutboundProtocol = + ::OutboundProtocol; + type InboundOpenInfo = + ::InboundOpenInfo; + type OutboundOpenInfo = + ::OutboundOpenInfo; + + delegate! 
{ + to self.0 { + fn listen_protocol(&self) -> SubstreamProtocol; + fn poll(&mut self, cx: &mut Context<'_>) -> Poll>; + fn on_behaviour_event(&mut self, event: Self::FromBehaviour); + fn on_connection_event(&mut self, event: ConnectionEvent); + } + } + + // specifically override this to force connection to stay alive + fn connection_keep_alive(&self) -> bool { + true + } +} diff --git a/rust/networking/src/lib.rs b/rust/networking/src/lib.rs new file mode 100644 index 00000000..a83bdc71 --- /dev/null +++ b/rust/networking/src/lib.rs @@ -0,0 +1,64 @@ +//! TODO: crate documentation +//! +//! this is here as a placeholder documentation +//! +//! + +// enable Rust-unstable features for convenience +#![feature(trait_alias)] +// #![feature(stmt_expr_attributes)] +// #![feature(unboxed_closures)] +// #![feature(assert_matches)] +// #![feature(async_fn_in_dyn_trait)] +// #![feature(async_for_loop)] +// #![feature(auto_traits)] +// #![feature(negative_impls)] + +pub mod discovery; +pub mod keep_alive; +pub mod swarm; + +/// Namespace for all the type/trait aliases used by this crate. +pub(crate) mod alias { + use std::error::Error; + + pub type AnyError = Box; + pub type AnyResult = Result; +} + +/// Namespace for crate-wide extension traits/methods +pub(crate) mod ext { + use std::net::IpAddr; + use extend::ext; + use libp2p::Multiaddr; + use libp2p::multiaddr::Protocol; + + #[ext(pub, name = MultiaddrExt)] + impl Multiaddr { + /// If the multiaddress corresponds to a TCP address, extracts it + fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> { + let mut ps = self.into_iter(); + let ip = if let Some(p) = ps.next() { + match p { + Protocol::Ip4(ip) => IpAddr::V4(ip), + Protocol::Ip6(ip) => IpAddr::V6(ip), + _ => return None + } + } else { + return None; + }; + let Some(Protocol::Tcp(port)) = ps.next() else { + return None; + }; + Some((ip, port)) + } + } +} + +pub(crate) mod private { + #![allow(dead_code)] + + /// Sealed traits support + pub trait Sealed {} + impl Sealed for T {} +} \ No newline at end of file diff --git a/rust/networking/src/swarm.rs b/rust/networking/src/swarm.rs new file mode 100644 index 00000000..24750558 --- /dev/null +++ b/rust/networking/src/swarm.rs @@ -0,0 +1,133 @@ +use crate::alias; +use crate::swarm::transport::tcp_transport; +pub use behaviour::{Behaviour, BehaviourEvent}; +use libp2p::{SwarmBuilder, identity}; + +pub type Swarm = libp2p::Swarm; + +/// The current version of the network: this prevents devices running different versions of the +/// software from interacting with each other. +/// +/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should +/// even be, and how to inject the right version into this config/initialization. E.g. should +/// this be passed in as a parameter? What about rapidly changing versions in debug builds? +/// this is all VERY very hard to figure out and needs to be mulled over as a team. +pub const NETWORK_VERSION: &[u8] = b"v0.0.1"; + +/// Create and configure a swarm which listens to all ports on OS +pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult { + let mut swarm = SwarmBuilder::with_existing_identity(keypair) + .with_tokio() + .with_other_transport(tcp_transport)? + .with_behaviour(Behaviour::new)? 
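+        // (the two closures above are `tcp_transport` and `Behaviour::new` from below;
+        // the SwarmBuilder hands each one a reference to the swarm's identity keypair)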
+ .build(); + + // Listen on all interfaces and whatever port the OS assigns + swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; + Ok(swarm) +} + +mod transport { + use crate::alias; + use crate::swarm::NETWORK_VERSION; + use futures::{AsyncRead, AsyncWrite}; + use keccak_const::Sha3_256; + use libp2p::core::muxing; + use libp2p::core::transport::Boxed; + use libp2p::pnet::{PnetError, PnetOutput}; + use libp2p::{PeerId, Transport, identity, noise, pnet, yamux}; + + /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`]. + /// See [`pnet_upgrade`] for more. + const PNET_PRESHARED_KEY: [u8; 32] = Sha3_256::new() + .update(b"exo_discovery_network") + .update(NETWORK_VERSION) + .finalize(); + + /// Make the Swarm run on a private network, as to not clash with public libp2p nodes and + /// also different-versioned instances of this same network. + /// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers. + async fn pnet_upgrade( + socket: TSocket, + _: impl Sized, + ) -> Result, PnetError> + where + TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + use pnet::{PnetConfig, PreSharedKey}; + PnetConfig::new(PreSharedKey::new(PNET_PRESHARED_KEY)) + .handshake(socket) + .await + } + + /// TCP/IP transport layer configuration. + pub fn tcp_transport( + keypair: &identity::Keypair, + ) -> alias::AnyResult> { + use libp2p::{ + core::upgrade::Version, + tcp::{Config, tokio}, + }; + + // `TCP_NODELAY` enabled => avoid latency + let tcp_config = Config::default().nodelay(true); + + // V1 + lazy flushing => 0-RTT negotiation + let upgrade_version = Version::V1Lazy; + + // Noise is faster than TLS + we don't care much for security + let noise_config = noise::Config::new(keypair)?; + + // Use default Yamux config for multiplexing + let yamux_config = yamux::Config::default(); + + // Create new Tokio-driven TCP/IP transport layer + let base_transport = tokio::Transport::new(tcp_config) + .and_then(pnet_upgrade) + .upgrade(upgrade_version) + .authenticate(noise_config) + .multiplex(yamux_config); + + // Return boxed transport (to flatten complex type) + Ok(base_transport.boxed()) + } +} + +mod behaviour { + use crate::{alias, discovery}; + use libp2p::swarm::NetworkBehaviour; + use libp2p::{gossipsub, identity}; + + /// Behavior of the Swarm which composes all desired behaviors: + /// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`]. 
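+    // Extending this later should just be a matter of one more field here plus its
+    // construction in `new` below; e.g. (hypothetical, not currently enabled):
+    //
+    //     pub identify: libp2p::identify::Behaviour,
+    //
+    // with a matching `identify: ...` line added alongside the others in `new`.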
+    #[derive(NetworkBehaviour)]
+    pub struct Behaviour {
+        pub discovery: discovery::Behaviour,
+        pub gossipsub: gossipsub::Behaviour,
+    }
+
+    impl Behaviour {
+        pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
+            Ok(Self {
+                discovery: discovery::Behaviour::new(keypair)?,
+                gossipsub: gossipsub_behaviour(keypair),
+            })
+        }
+    }
+
+    fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
+        use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
+
+        // build a gossipsub network behaviour
+        // => signed message authenticity + strict validation mode means the message-ID is
+        //    automatically provided by gossipsub w/out needing to provide a custom message-ID function
+        gossipsub::Behaviour::new(
+            MessageAuthenticity::Signed(keypair.clone()),
+            ConfigBuilder::default()
+                .validation_mode(ValidationMode::Strict)
+                .build()
+                .expect("the configuration should always be valid"),
+        )
+        .expect("creating gossipsub behavior should always work")
+    }
+}
diff --git a/rust/networking/tests/dummy.rs b/rust/networking/tests/dummy.rs
new file mode 100644
index 00000000..ddaa8cc2
--- /dev/null
+++ b/rust/networking/tests/dummy.rs
@@ -0,0 +1,7 @@
+// maybe this will hold tests in the future...??
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn does_nothing() {}
+}
diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml
new file mode 100644
index 00000000..271800cb
--- /dev/null
+++ b/rust/rust-toolchain.toml
@@ -0,0 +1,2 @@
+[toolchain]
+channel = "nightly"
\ No newline at end of file
diff --git a/rust/system_custodian/Cargo.toml b/rust/system_custodian/Cargo.toml
new file mode 100644
index 00000000..46e530b1
--- /dev/null
+++ b/rust/system_custodian/Cargo.toml
@@ -0,0 +1,47 @@
+[package]
+name = "system_custodian"
+version = { workspace = true }
+edition = { workspace = true }
+publish = false
+
+[lib]
+doctest = false
+name = "system_custodian"
+path = "src/lib.rs"
+
+[[bin]]
+path = "src/bin/main.rs"
+name = "system_custodian"
+doc = false
+
+[lints]
+workspace = true
+
+[dependencies]
+# datastructures
+either = { workspace = true }
+
+# macro dependencies
+extend = { workspace = true }
+delegate = { workspace = true }
+impl-trait-for-tuples = { workspace = true }
+derive_more = { workspace = true }
+
+# async
+tokio = { workspace = true, features = ["full"] }
+futures = { workspace = true }
+futures-timer = { workspace = true }
+
+# utility dependencies
+util = { workspace = true }
+thiserror = { workspace = true }
+#internment = { workspace = true }
+#recursion = { workspace = true }
+#generativity = { workspace = true }
+#itertools = { workspace = true }
+tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
+keccak-const = { workspace = true }
+
+# tracing/logging
+log = { workspace = true }
+
diff --git a/rust/system_custodian/src/bin/main.rs b/rust/system_custodian/src/bin/main.rs
new file mode 100644
index 00000000..2345c633
--- /dev/null
+++ b/rust/system_custodian/src/bin/main.rs
@@ -0,0 +1,4 @@
+//! TODO: documentation
+//!
+
+fn main() {}
diff --git a/rust/system_custodian/src/lib.rs b/rust/system_custodian/src/lib.rs
new file mode 100644
index 00000000..cf856239
--- /dev/null
+++ b/rust/system_custodian/src/lib.rs
@@ -0,0 +1,69 @@
+//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
+//!
+//! The **_System Custodian_** daemon is intended to be a long-lived process that precedes the
+//! launch of the Exo application, and is responsible for ensuring the system (configuration,
+//! settings, etc.) is in an appropriate state to facilitate running the Exo application.
+//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
+//! service which the Exo application uses to _control & query_ it.
+//!
+//! # Lifecycle
+//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
+//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
+//! the daemon wakes up, it will configure the system into a state suitable for the Exo application;
+//! when the daemon goes to sleep, it will revert those changes as much as it can, in case they were
+//! destructive to the user's pre-existing configuration.
+//!
+//! # Responsibilities
+//! TODO: these are macOS-only for now, but should be broadened
+//! The **_System Custodian_** daemon is responsible for using the System Configuration framework to
+//! 1. duplicate the current network set
+//! 2. modify existing services to turn on IPv6 where it isn't already enabled
+//! 3. remove any bridge services & add any missing services that AREN'T bridges
+//!
+//! TODO: In the future:
+//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
+//! 2. toggle some GPU/memory configurations to speed up the GPU (ask Alex what those configurations are)
+//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
+//!    logic, this would be the place to spin that up.
+//!
+//! Then it will watch the SCDynamicStore for:
+//! 1. all __actual__ network interfaces -> collect information on them, e.g. their BSD name, MAC
+//!    address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the D-Bus
+//!    interface of any changes
+//! 2. any __undesirable__ changes to the configuration, which it will revert
+//!
+//! It should somehow (probably through system sockets and/or the BSD interface) trigger IPv6 NDP on
+//! each of the interfaces & also listen to/query for any changes to the OS routing cache??
+//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
+//! 1. all that info should coalesce back into the overall collected state -> should be queryable
+//!    over D-Bus
+//!
+//! TODO:
+//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
+//!    ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
+//!    handshake require knowing the Node ID? Will the handshake require heartbeats? Who knows...
+//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
+//!    e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
+//!    for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
+//!    then this would be the place to carry out discovery and proper handshakes with devices
+//!    on the other end of the link.
+//!
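+//! # Example: stopgap NDP probe (sketch)
+//! Until the socket-level version exists, the intended behaviour can be approximated by
+//! shelling out to the macOS tools named above. This is an illustrative sketch only, not
+//! part of the daemon; the interface-name handling and output parsing are assumptions:
+//! ```ignore
+//! use std::process::Command;
+//!
+//! /// Ping the all-nodes multicast address on `interface` to trigger NDP,
+//! /// then dump the neighbour cache via `ndp -an`.
+//! fn probe_ipv6_neighbours(interface: &str) -> std::io::Result<String> {
+//!     Command::new("ping6")
+//!         .args(["-c", "1", &format!("ff02::1%{interface}")])
+//!         .status()?;
+//!     let out = Command::new("ndp").arg("-an").output()?;
+//!     Ok(String::from_utf8_lossy(&out.stdout).into_owned())
+//! }
+//! ```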
+
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+#![feature(stmt_expr_attributes)]
+#![feature(type_alias_impl_trait)]
+#![feature(specialization)]
+#![feature(unboxed_closures)]
+#![feature(const_trait_impl)]
+#![feature(fn_traits)]
+
+pub(crate) mod private {
+    // sealed traits support
+    pub trait Sealed {}
+    impl<T: ?Sized> Sealed for T {}
+}
+
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {}
+
+/// Namespace for crate-wide extension traits/methods
+pub(crate) mod ext {}
diff --git a/rust/util/Cargo.toml b/rust/util/Cargo.toml
new file mode 100644
index 00000000..aeae3534
--- /dev/null
+++ b/rust/util/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "util"
+version = { workspace = true }
+edition = { workspace = true }
+publish = false
+
+[lib]
+doctest = false
+name = "util"
+path = "src/lib.rs"
+
+[lints]
+workspace = true
+
+[dependencies]
+# macro dependencies
+extend = { workspace = true }
+
+# utility dependencies
+thiserror = { workspace = true }
+once_cell = { workspace = true }
+internment = { workspace = true }
+derive_more = { workspace = true }
+bon = { workspace = true }
+recursion = { workspace = true }
diff --git a/rust/util/src/lib.rs b/rust/util/src/lib.rs
new file mode 100644
index 00000000..60e11f3a
--- /dev/null
+++ b/rust/util/src/lib.rs
@@ -0,0 +1,53 @@
+//! TODO: crate documentation
+//!
+//! This is placeholder documentation.
+//!
+//!
+
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+#![feature(stmt_expr_attributes)]
+#![feature(type_alias_impl_trait)]
+#![feature(specialization)]
+#![feature(unboxed_closures)]
+#![feature(const_trait_impl)]
+#![feature(fn_traits)]
+
+pub mod nonempty;
+pub mod wakerdeque;
+
+pub(crate) mod private {
+    // sealed traits support
+    pub trait Sealed {}
+    impl<T: ?Sized> Sealed for T {}
+}
+
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {}
+
+/// Namespace for crate-wide extension traits/methods
+pub mod ext {
+    use extend::ext;
+
+    #[ext(pub, name = BoxedSliceExt)]
+    impl<T> Box<[T]> {
+        #[inline]
+        fn map<B, F>(self, f: F) -> Box<[B]>
+        where
+            F: FnMut(T) -> B,
+        {
+            self.into_iter().map(f).collect()
+        }
+    }
+
+    #[ext(pub, name = VecExt)]
+    impl<T> Vec<T> {
+        #[inline]
+        fn map<B, F>(self, f: F) -> Vec<B>
+        where
+            F: FnMut(T) -> B,
+        {
+            self.into_iter().map(f).collect()
+        }
+    }
+}
diff --git a/rust/util/src/nonempty.rs b/rust/util/src/nonempty.rs
new file mode 100644
index 00000000..e9eb8620
--- /dev/null
+++ b/rust/util/src/nonempty.rs
@@ -0,0 +1,138 @@
+use std::slice::SliceIndex;
+use std::{ops, slice};
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+#[error("Cannot create a `NonemptyArray` because the supplied slice is empty")]
+pub struct EmptySliceError;
+
+/// A pointer to a non-empty fixed-size slice allocated on the heap.
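+///
+/// A minimal usage sketch (the constructors shown are the ones defined below):
+/// ```ignore
+/// let one = NonemptyArray::singleton(1);
+/// let many = NonemptyArray::try_from_boxed_slice(vec![1, 2, 3])?;
+/// assert_eq!(*many.first(), 1);
+/// assert_eq!(*many.last(), 3);
+/// ```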
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[repr(transparent)]
+pub struct NonemptyArray<T>(Box<[T]>);
+
+#[allow(clippy::arbitrary_source_item_ordering)]
+impl<T> NonemptyArray<T> {
+    #[inline]
+    pub fn singleton(value: T) -> Self {
+        Self(Box::new([value]))
+    }
+
+    #[allow(clippy::missing_errors_doc)]
+    #[inline]
+    pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
+        boxed_slice: S,
+    ) -> Result<Self, EmptySliceError> {
+        let boxed_slice = boxed_slice.into();
+        if boxed_slice.is_empty() {
+            Err(EmptySliceError)
+        } else {
+            Ok(Self(boxed_slice))
+        }
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn into_boxed_slice(self) -> Box<[T]> {
+        self.0
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn to_vec(&self) -> Vec<T>
+    where
+        T: Clone,
+    {
+        self.0.to_vec()
+    }
+
+    #[must_use]
+    #[inline]
+    pub const fn as_slice(&self) -> &[T] {
+        &self.0
+    }
+
+    #[allow(clippy::indexing_slicing)]
+    #[must_use]
+    #[inline]
+    pub fn first(&self) -> &T {
+        &self.0[0]
+    }
+
+    #[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
+    #[must_use]
+    #[inline]
+    pub fn last(&self) -> &T {
+        &self.0[self.0.len() - 1]
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn get<I>(&self, index: I) -> Option<&I::Output>
+    where
+        I: SliceIndex<[T]>,
+    {
+        self.0.get(index)
+    }
+
+    #[allow(clippy::len_without_is_empty)]
+    #[must_use]
+    #[inline]
+    pub const fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    #[allow(clippy::iter_without_into_iter)]
+    #[inline]
+    pub fn iter(&self) -> slice::Iter<'_, T> {
+        self.0.iter()
+    }
+
+    #[allow(clippy::iter_without_into_iter)]
+    #[inline]
+    pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
+        self.0.iter_mut()
+    }
+
+    #[inline]
+    #[must_use]
+    pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
+        NonemptyArray(self.0.into_iter().map(f).collect())
+    }
+}
+
+impl<T> From<NonemptyArray<T>> for Box<[T]> {
+    #[inline]
+    fn from(value: NonemptyArray<T>) -> Self {
+        value.into_boxed_slice()
+    }
+}
+
+impl<T> ops::Index<usize> for NonemptyArray<T> {
+    type Output = T;
+
+    #[inline]
+    fn index(&self, index: usize) -> &Self::Output {
+        self.0.index(index)
+    }
+}
+
+impl<T> IntoIterator for NonemptyArray<T> {
+    type Item = T;
+    type IntoIter = std::vec::IntoIter<T>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.into_boxed_slice().into_vec().into_iter()
+    }
+}
+
+impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
+    type Item = &'a T;
+    type IntoIter = slice::Iter<'a, T>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
diff --git a/rust/util/src/wakerdeque.rs b/rust/util/src/wakerdeque.rs
new file mode 100644
index 00000000..336c0347
--- /dev/null
+++ b/rust/util/src/wakerdeque.rs
@@ -0,0 +1,55 @@
+use std::collections::VecDeque;
+use std::fmt::{Debug, Formatter};
+use std::task::{Context, Waker};
+
+/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
+/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
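+///
+/// A sketch of the intended use inside a hand-written `poll` method (the surrounding
+/// type and its `queue: WakerDeque<Item>` field are hypothetical):
+/// ```ignore
+/// fn poll_next_item(&mut self, cx: &mut Context<'_>) -> Poll<Item> {
+///     // `pop_front` stores cx's waker, so a later `push_back` re-wakes this task
+///     match self.queue.pop_front(cx) {
+///         Some(item) => Poll::Ready(item),
+///         None => Poll::Pending,
+///     }
+/// }
+/// ```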
+pub struct WakerDeque<T> {
+    waker: Option<Waker>,
+    deque: VecDeque<T>,
+}
+
+impl<T: Debug> Debug for WakerDeque<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        self.deque.fmt(f)
+    }
+}
+
+impl<T> WakerDeque<T> {
+    pub fn new() -> Self {
+        Self {
+            waker: None,
+            deque: VecDeque::new(),
+        }
+    }
+
+    fn update(&mut self, cx: &mut Context<'_>) {
+        self.waker = Some(cx.waker().clone());
+    }
+
+    fn wake(&mut self) {
+        let Some(ref mut w) = self.waker else { return };
+        w.wake_by_ref();
+        self.waker = None;
+    }
+
+    pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
+        self.update(cx);
+        self.deque.pop_front()
+    }
+
+    pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
+        self.update(cx);
+        self.deque.pop_back()
+    }
+
+    pub fn push_front(&mut self, value: T) {
+        self.wake();
+        self.deque.push_front(value);
+    }
+
+    pub fn push_back(&mut self, value: T) {
+        self.wake();
+        self.deque.push_back(value);
+    }
+}
diff --git a/src/exo/engines/mlx/__init__.py b/src/exo/engines/mlx/__init__.py
index 3672ffac..716ee0b9 100644
--- a/src/exo/engines/mlx/__init__.py
+++ b/src/exo/engines/mlx/__init__.py
@@ -8,9 +8,10 @@ import mlx.nn as nn  # type: ignore
 # These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
 # For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
+
 class Model(nn.Module):
     layers: list[nn.Module]
-
+
     def __call__(self, x: mx.array, cache: Optional[list[KVCache]]) -> mx.array: ...
@@ -18,7 +19,7 @@ class Detokenizer:
     def reset(self) -> None: ...
     def add_token(self, token: int) -> None: ...
     def finalize(self) -> None: ...
-
+
     @property
     def last_segment(self) -> str: ...
@@ -27,5 +28,5 @@ class TokenizerWrapper:
     bos_token: Optional[str]
     eos_token_ids: list[int]
     detokenizer: Detokenizer
-
-    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
\ No newline at end of file
+
+    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py index 72b99584..774af661 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/engines/mlx/utils_mlx.py @@ -29,6 +29,7 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) mlx_rank: None | int = None mlx_world_size: None | int = None + def mx_barrier(): mx.eval( # type: ignore mx.distributed.all_sum( @@ -36,6 +37,7 @@ def mx_barrier(): ) ) + def broadcast_from_zero(value: int) -> int: if mlx_rank is None: return value @@ -46,8 +48,9 @@ def broadcast_from_zero(value: int) -> int: a = mx.array([0], dtype=mx.int32) m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu)) - mx.eval(m) # type: ignore - return int(m.item()) # type: ignore + mx.eval(m) # type: ignore + return int(m.item()) # type: ignore + class HostList(RootModel[list[str]]): @classmethod @@ -83,7 +86,7 @@ def mlx_setup( if wired_frac_of_mrwss > 0.0: target_wired = int(wired_frac_of_mrwss * mrwss) target_wired = min(target_wired, target_cache) # don’t wire more than cache - + runner_print(f"{target_wired=}") with contextlib.suppress(Exception): # older macOS won’t have this mx.set_wired_limit(max(target_wired, 0)) @@ -136,14 +139,14 @@ def initialize_mlx( def shard_and_load( - model_shard_meta: ShardMetadata, + model_shard_meta: ShardMetadata, ) -> tuple[nn.Module, TokenizerWrapper]: model_path = build_model_path(model_shard_meta.model_meta.model_id) runner_print(f"loading model from {model_path}") model, config = load_model(model_path, lazy=True, strict=False) # type: ignore - runner_print(f'{config=}') + runner_print(f"{config=}") assert isinstance(model, nn.Module) tokenizer = load_tokenizer(model_path) @@ -154,7 +157,7 @@ def shard_and_load( # Synchronize processes before generation to avoid timeout mx_barrier() - return model, tokenizer # type: ignore + return model, tokenizer # type: ignore async def apply_chat_template( @@ -199,11 +202,13 @@ async def apply_chat_template( return prompt + class NullKVCache(KVCache): """ A KVCache that pretends to exist but holds zero tokens. It satisfies .state/.meta_state and never allocates real keys/values. """ + def __init__(self, dtype: mx.Dtype = mx.float16): super().__init__() # zero-length K/V so shapes/dtypes are defined but empty @@ -218,19 +223,21 @@ class NullKVCache(KVCache): @state.setter def state(self, v: tuple[mx.array, mx.array]) -> None: - raise NotImplementedError('We should not be setting a NullKVCache.') + raise NotImplementedError("We should not be setting a NullKVCache.") + async def make_kv_cache( model: Model, max_kv_size: Optional[int] = None, ) -> list[KVCache]: - assert hasattr(model, 'layers') - + assert hasattr(model, "layers") + return [ NullKVCache() if isinstance(layer, IdentityLayer) else KVCache() for layer in model.layers ] + def mlx_force_oom(size: int = 40000) -> None: """ Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations. diff --git a/src/exo/main.py b/src/exo/main.py index bbcc08c9..988a861b 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,41 +1,221 @@ import argparse -import multiprocessing as mp +from dataclasses import dataclass +from typing import Self +import anyio +from anyio.abc import TaskGroup from loguru import logger +from pydantic import PositiveInt -from exo.master.main import main as master_main +import exo.routing.topics as topics +from exo.master.api import API # TODO: should API be in master? 
+from exo.master.main import Master
+from exo.routing.router import Router, get_node_id_keypair
 from exo.shared.constants import EXO_LOG
+from exo.shared.election import Election, ElectionResult
 from exo.shared.logging import logger_cleanup, logger_setup
-from exo.worker.main import main as worker_main
+from exo.shared.types.common import NodeId
+from exo.utils.channels import Receiver, channel
+from exo.utils.pydantic_ext import CamelCaseModel
+from exo.worker.download.impl_shard_downloader import exo_shard_downloader
+from exo.worker.main import Worker
+
+
+# TODO: Entrypoint refactor
+# I marked this as a dataclass as I want trivial constructors.
+# This is the collection of systems for our entire application.
+@dataclass
+class Node:
+    router: Router
+    worker: Worker
+    election: Election  # Every node participates in the election: a node should still become master, even if it isn't a master candidate, when no master candidates are present.
+    election_result_receiver: Receiver[ElectionResult]
+    master: Master | None
+    api: API | None
+
+    node_id: NodeId
+    _tg: TaskGroup | None = None
+
+    @classmethod
+    async def create(cls, args: "Args") -> "Self":
+        keypair = get_node_id_keypair()
+        node_id = NodeId(keypair.to_peer_id().to_base58())
+        router = Router.create(keypair)
+        await router.register_topic(topics.GLOBAL_EVENTS)
+        await router.register_topic(topics.LOCAL_EVENTS)
+        await router.register_topic(topics.COMMANDS)
+        await router.register_topic(topics.ELECTION_MESSAGES)
+        await router.register_topic(topics.CONNECTION_MESSAGES)
+
+        logger.info(f"Starting node {node_id}")
+        if args.spawn_api:
+            api = API(
+                node_id=node_id,
+                port=args.api_port,
+                global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
+                command_sender=router.sender(topics.COMMANDS),
+            )
+        else:
+            api = None
+
+        worker = Worker(
+            node_id,
+            exo_shard_downloader(),
+            initial_connection_messages=[],
+            connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
+            global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
+            local_event_sender=router.sender(topics.LOCAL_EVENTS),
+            command_sender=router.sender(topics.COMMANDS),
+        )
+        # We start every node with a master
+        master = Master(
+            node_id,
+            global_event_sender=router.sender(topics.GLOBAL_EVENTS),
+            local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
+            command_receiver=router.receiver(topics.COMMANDS),
+            tb_only=args.tb_only,
+        )
+
+        # If someone manages to assemble 1 MILLION devices into an exo cluster then... well done, good job champ.
+        er_send, er_recv = channel[ElectionResult]()
+        election = Election(
+            node_id,
+            seniority=1_000_000 if args.force_master else 0,
+            # nb: this DOES feedback right now.
i have thoughts on how to address this, + # but ultimately it seems not worth the complexity + election_message_sender=router.sender(topics.ELECTION_MESSAGES), + election_message_receiver=router.receiver(topics.ELECTION_MESSAGES), + connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES), + election_result_sender=er_send, + ) + + return cls(router, worker, election, er_recv, master, api, node_id) + + async def run(self): + async with anyio.create_task_group() as tg: + self._tg = tg + tg.start_soon(self.router.run) + tg.start_soon(self.worker.run) + tg.start_soon(self.election.run) + if self.master: + tg.start_soon(self.master.run) + if self.api: + tg.start_soon(self.api.run) + tg.start_soon(self._elect_loop) + + async def _elect_loop(self): + assert self._tg + with self.election_result_receiver as results: + async for result in results: + # I don't like this duplication, but it's manageable for now. + # TODO: This function needs refactoring generally + + # Ok: + # On new master: + # - Elect master locally if necessary + # - Shutdown and re-create the worker + # - Shut down and re-create the API + + if result.node_id == self.node_id and self.master is not None: + logger.info("Node elected Master") + elif result.node_id == self.node_id and self.master is None: + logger.info("Node elected Master - promoting self") + self.master = Master( + self.node_id, + global_event_sender=self.router.sender(topics.GLOBAL_EVENTS), + local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS), + command_receiver=self.router.receiver(topics.COMMANDS), + ) + self._tg.start_soon(self.master.run) + elif result.node_id != self.node_id and self.master is not None: + logger.info(f"Node {result.node_id} elected master - demoting self") + await self.master.shutdown() + self.master = None + else: + logger.info(f"Node {result.node_id} elected master") + if result.is_new_master: + await anyio.sleep(0) + if self.worker: + self.worker.shutdown() + # TODO: add profiling etc to resource monitor + self.worker = Worker( + self.node_id, + exo_shard_downloader(), + initial_connection_messages=result.historic_messages, + connection_message_receiver=self.router.receiver( + topics.CONNECTION_MESSAGES + ), + global_event_receiver=self.router.receiver( + topics.GLOBAL_EVENTS + ), + local_event_sender=self.router.sender(topics.LOCAL_EVENTS), + command_sender=self.router.sender(topics.COMMANDS), + ) + self._tg.start_soon(self.worker.run) + if self.api: + self.api.reset() def main(): - parser = argparse.ArgumentParser(prog="exo") - parser.add_argument( - "-v", "--verbose", action="store_const", const=1, dest="verbosity", default=0 - ) - parser.add_argument( - "-vv", - "--very-verbose", - action="store_const", - const=2, - dest="verbosity", - default=0, - ) - args = parser.parse_args() - if type(args.verbosity) is not int: # type: ignore - raise TypeError("Verbosity was parsed incorrectly") + args = Args.parse() + # TODO: Refactor the current verbosity system logger_setup(EXO_LOG, args.verbosity) - logger.info("starting exo") + logger.info("Starting EXO") - # This is for future PyInstaller compatibility - mp.set_start_method("spawn", force=True) - - worker = mp.Process(target=worker_main, args=(EXO_LOG, args.verbosity)) - master = mp.Process(target=master_main, args=(EXO_LOG, args.verbosity)) - worker.start() - master.start() - worker.join() - master.join() + node = anyio.run(Node.create, args) + anyio.run(node.run) logger_cleanup() + + +class Args(CamelCaseModel): + verbosity: int = 0 + force_master: bool = False + 
spawn_api: bool = False + api_port: PositiveInt = 8000 + tb_only: bool = False + + @classmethod + def parse(cls) -> Self: + parser = argparse.ArgumentParser(prog="EXO") + default_verbosity = 0 + parser.add_argument( + "-q", + "--quiet", + action="store_const", + const=-1, + dest="verbosity", + default=default_verbosity, + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + dest="verbosity", + default=default_verbosity, + ) + parser.add_argument( + "-m", + "--force-master", + action="store_true", + dest="force_master", + ) + parser.add_argument( + "--no-api", + action="store_false", + dest="spawn_api", + ) + parser.add_argument( + "--api-port", + type=int, + dest="api_port", + default=8000, + ) + parser.add_argument( + "--tb-only", + action="store_true", + dest="tb_only", + ) + + args = parser.parse_args() + return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically diff --git a/src/exo/master/api.py b/src/exo/master/api.py index f37418a4..ebd66786 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -2,16 +2,18 @@ import asyncio import os import time from collections.abc import AsyncGenerator -from typing import Callable, List, Sequence, final +from typing import final import uvicorn +from anyio import create_task_group +from anyio.abc import TaskGroup from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from loguru import logger -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage +from exo.shared.apply import apply from exo.shared.models.model_cards import MODEL_CARDS from exo.shared.models.model_meta import get_model_meta from exo.shared.types.api import ( @@ -24,23 +26,26 @@ from exo.shared.types.api import ( ModelListModel, StreamingChoiceResponse, ) -from exo.shared.types.common import CommandId -from exo.shared.types.events import ChunkGenerated, Event -from exo.shared.types.events.chunks import TokenChunk -from exo.shared.types.events.commands import ( - ChatCompletionCommand, +from exo.shared.types.chunks import TokenChunk +from exo.shared.types.commands import ( + ChatCompletion, Command, - CommandType, - CreateInstanceCommand, - DeleteInstanceCommand, - TaskFinishedCommand, + CreateInstance, + DeleteInstance, + ForwarderCommand, + TaggedCommand, + # TODO: SpinUpInstance + TaskFinished, ) -from exo.shared.types.events.components import EventFromEventLog +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent from exo.shared.types.models import ModelMetadata from exo.shared.types.state import State from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.instances import Instance +from exo.utils.channels import Receiver, Sender +from exo.utils.event_buffer import OrderedBuffer def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse: @@ -70,26 +75,50 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata: class API: def __init__( self, - command_buffer: List[Command], - global_events: AsyncSQLiteEventStorage, - get_state: Callable[[], State], + *, + node_id: NodeId, + port: int = 8000, + # Ideally this would be a MasterForwarderEvent but type system says no :( + global_event_receiver: Receiver[ForwarderEvent], + command_sender: Sender[ForwarderCommand], ) -> 
None: - self.get_state = get_state - self.command_buffer = command_buffer - self.global_events = global_events + self.state = State() + self.command_sender = command_sender + self.global_event_receiver = global_event_receiver + self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]() + self.node_id: NodeId = node_id + self.port = port - self._app = FastAPI() + self.app = FastAPI() self._setup_cors() self._setup_routes() - self._app.mount( + self.app.mount( "/", - StaticFiles(directory=os.environ["DASHBOARD_DIR"], html=True), + StaticFiles( + directory=os.environ.get( + "DASHBOARD_DIR", + os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../dashboard") + ), + ), + html=True, + ), name="dashboard", ) + self._chat_completion_queues: dict[ + CommandId, asyncio.Queue[ChunkGenerated] + ] = {} + self._tg: TaskGroup | None = None + + def reset(self): + self.state = State() + self.event_buffer = OrderedBuffer[Event]() + self._chat_completion_queues = {} + def _setup_cors(self) -> None: - self._app.add_middleware( + self.app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, @@ -98,23 +127,19 @@ class API: ) def _setup_routes(self) -> None: - self._app.post("/instance")(self.create_instance) - self._app.get("/instance/{instance_id}")(self.get_instance) - self._app.delete("/instance/{instance_id}")(self.delete_instance) - self._app.get("/models")(self.get_models) - self._app.get("/v1/models")(self.get_models) - self._app.post("/v1/chat/completions")(self.chat_completions) - self._app.get("/state")(self.get_state) - - @property - def app(self) -> FastAPI: - return self._app + self.app.post("/instance")(self.create_instance) + self.app.get("/instance/{instance_id}")(self.get_instance) + self.app.delete("/instance/{instance_id}")(self.delete_instance) + self.app.get("/models")(self.get_models) + self.app.get("/v1/models")(self.get_models) + self.app.post("/v1/chat/completions")(self.chat_completions) + self.app.get("/state")(lambda: self.state) async def create_instance( self, payload: CreateInstanceTaskParams ) -> CreateInstanceResponse: model_meta = await resolve_model_meta(payload.model_id) - required_memory_bytes = model_meta.storage_size_kilobytes * 1024 + required_memory_bytes = model_meta.storage_size.in_kb available_memory_bytes = self._calculate_total_available_memory() if required_memory_bytes > available_memory_bytes: @@ -123,37 +148,33 @@ class API: detail=f"Insufficient memory to create instance. 
Required: {required_memory_bytes // (1024**3):.1f}GB, Available: {available_memory_bytes // (1024**3):.1f}GB", ) - command = CreateInstanceCommand( + command = CreateInstance( command_id=CommandId(), - command_type=CommandType.CREATE_INSTANCE, model_meta=model_meta, - instance_id=InstanceId(), ) - self.command_buffer.append(command) + await self._send(command) return CreateInstanceResponse( message="Command received.", command_id=command.command_id, model_meta=model_meta, - instance_id=command.instance_id, ) def get_instance(self, instance_id: InstanceId) -> Instance: - state = self.get_state() + state = self.state if instance_id not in state.instances: raise HTTPException(status_code=404, detail="Instance not found") return state.instances[instance_id] - def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse: - if instance_id not in self.get_state().instances: + async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse: + if instance_id not in self.state.instances: raise HTTPException(status_code=404, detail="Instance not found") - command = DeleteInstanceCommand( + command = DeleteInstance( command_id=CommandId(), - command_type=CommandType.DELETE_INSTANCE, instance_id=instance_id, ) - self.command_buffer.append(command) + await self._send(command) return DeleteInstanceResponse( message="Command received.", command_id=command.command_id, @@ -165,37 +186,27 @@ class API: ) -> AsyncGenerator[str, None]: """Generate chat completion stream as JSON strings.""" - events = await self.global_events.get_events_since(0) - prev_idx = await self.global_events.get_last_idx() + self._chat_completion_queues[command_id] = asyncio.Queue() finished = False while not finished: - await asyncio.sleep(0.01) + # TODO: how long should this timeout be? + chunk = await asyncio.wait_for( + self._chat_completion_queues[command_id].get(), timeout=60 + ) + if chunk.command_id == command_id: + assert isinstance(chunk.chunk, TokenChunk) + chunk_response: ChatCompletionResponse = chunk_to_response(chunk.chunk) + logger.debug(f"chunk_response: {chunk_response}") + yield f"data: {chunk_response.model_dump_json()}\n\n" - events: Sequence[ - EventFromEventLog[Event] - ] = await self.global_events.get_events_since(prev_idx) - # TODO: Can do this with some better functionality to tail event log into an AsyncGenerator. 
- prev_idx = events[-1].idx_in_log if events else prev_idx + if chunk.chunk.finish_reason is not None: + yield "data: [DONE]\n\n" + finished = True - for wrapped_event in events: - event = wrapped_event.event - if isinstance(event, ChunkGenerated) and event.command_id == command_id: - assert isinstance(event.chunk, TokenChunk) - chunk_response: ChatCompletionResponse = chunk_to_response( - event.chunk - ) - logger.debug(chunk_response) - yield f"data: {chunk_response.model_dump_json()}\n\n" - - if event.chunk.finish_reason is not None: - yield "data: [DONE]" - finished = True - - command = TaskFinishedCommand(command_id=command_id) - self.command_buffer.append(command) - - return + command = TaskFinished(finished_command_id=command_id) + await self._send(command) + del self._chat_completion_queues[command_id] async def _trigger_notify_user_to_download_model(self, model_id: str) -> None: logger.warning( @@ -210,6 +221,7 @@ class API: payload.model = model_meta.model_id # Preprocess messages for GPT-OSS harmony format if needed + # TODO: This is slop surely we get rid if "gpt-oss" in payload.model.lower(): import re @@ -233,7 +245,7 @@ class API: # Store thinking in the thinking field message.thinking = thinking_match.group(1).strip() - for instance in self.get_state().instances.values(): + for instance in self.state.instances.values(): if instance.shard_assignments.model_id == payload.model: break else: @@ -242,23 +254,22 @@ class API: status_code=404, detail=f"No instance found for model {payload.model}" ) - command = ChatCompletionCommand( + command = ChatCompletion( command_id=CommandId(), - command_type=CommandType.CHAT_COMPLETION, request_params=payload, ) - self.command_buffer.append(command) + await self._send(command) return StreamingResponse( self._generate_chat_stream(command.command_id), media_type="text/plain" ) def _calculate_total_available_memory(self) -> int: """Calculate total available memory across all nodes in bytes.""" - state = self.get_state() total_available = 0 - for node_profile in state.node_profiles.values(): - total_available += node_profile.memory.ram_available + for node in self.state.topology.list_nodes(): + if node.node_profile is not None: + total_available += node.node_profile.memory.ram_available.in_bytes return total_available @@ -277,14 +288,35 @@ class API: ] ) + async def run(self): + uvicorn_config = uvicorn.Config( + self.app, host="0.0.0.0", port=self.port, access_log=False + ) + uvicorn_server = uvicorn.Server(uvicorn_config) -def start_fastapi_server( - command_buffer: List[Command], - global_events: AsyncSQLiteEventStorage, - get_state: Callable[[], State], - host: str = "0.0.0.0", - port: int = 8000, -): - api = API(command_buffer, global_events, get_state) + async with create_task_group() as tg: + self._tg = tg + logger.info("Starting API") + tg.start_soon(uvicorn_server.serve) + tg.start_soon(self._apply_state) + self.command_sender.close() + self.global_event_receiver.close() - uvicorn.run(api.app, host=host, port=port) + async def _apply_state(self): + with self.global_event_receiver as events: + async for event in events: + self.event_buffer.ingest(event.origin_idx, event.tagged_event.c) + for idx, event in self.event_buffer.drain_indexed(): + self.state = apply(self.state, IndexedEvent(event=event, idx=idx)) + if ( + isinstance(event, ChunkGenerated) + and event.command_id in self._chat_completion_queues + ): + self._chat_completion_queues[event.command_id].put_nowait(event) + + async def _send(self, command: Command): + await 
self.command_sender.send( + ForwarderCommand( + origin=self.node_id, tagged_command=TaggedCommand.from_(command) + ) + ) diff --git a/src/exo/master/election_callback.py b/src/exo/master/election_callback.py deleted file mode 100644 index 0d2ad65c..00000000 --- a/src/exo/master/election_callback.py +++ /dev/null @@ -1,23 +0,0 @@ -from loguru import logger - -from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor - - -class ElectionCallbacks: - """ - Simple callbacks for the Rust election system to invoke. - No event system involvement - just direct forwarder control. - """ - - def __init__(self, forwarder_supervisor: ForwarderSupervisor): - self._forwarder_supervisor = forwarder_supervisor - - async def on_became_master(self) -> None: - """Called when this node is elected as master""" - logger.info("Node elected as master") - await self._forwarder_supervisor.notify_role_change(ForwarderRole.MASTER) - - async def on_became_replica(self) -> None: - """Called when this node becomes a replica""" - logger.info("Node demoted to replica") - await self._forwarder_supervisor.notify_role_change(ForwarderRole.REPLICA) diff --git a/src/exo/master/env.py b/src/exo/master/env.py deleted file mode 100644 index 3b703d93..00000000 --- a/src/exo/master/env.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path - -from exo.shared.env import BaseEnv - - -class MasterEnvironmentSchema(BaseEnv): - # Master-specific: forwarder configuration - # Default to build/forwarder if not explicitly set - FORWARDER_BINARY_PATH: Path = Path("build/forwarder") diff --git a/src/exo/master/forwarder_supervisor.py b/src/exo/master/forwarder_supervisor.py index 1ff87d5d..f4f4e5b1 100644 --- a/src/exo/master/forwarder_supervisor.py +++ b/src/exo/master/forwarder_supervisor.py @@ -10,7 +10,7 @@ from exo.shared.constants import ( EXO_GLOBAL_EVENT_DB, EXO_WORKER_EVENT_DB, LIBP2P_GLOBAL_EVENTS_TOPIC, - LIBP2P_WORKER_EVENTS_TOPIC, + LIBP2P_LOCAL_EVENTS_TOPIC, ) from exo.shared.types.common import NodeId @@ -58,9 +58,7 @@ class ForwarderSupervisor: if self._current_role == new_role: logger.debug(f"Role unchanged: {new_role}") return - logger.bind(user_facing=True).info( - f"Node changing from {self._current_role} to {new_role}" - ) + logger.info(f"Node changing from {self._current_role} to {new_role}") self._current_role = new_role await self._restart_with_role(new_role) @@ -82,13 +80,13 @@ class ForwarderSupervisor: # Both master and replica forward local worker events to network pairs.append( - f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}" + f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}" ) if role == ForwarderRole.MASTER: # Master: collect worker events from network into global log pairs.append( - f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events" + f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events" ) # Master: broadcast global events to network pairs.append( diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 18d77c4a..443a2803 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -1,272 +1,240 @@ -import asyncio -import os -import threading -from pathlib import Path - +from anyio import create_task_group +from anyio.abc import TaskGroup from loguru import logger -from exo.master.api import start_fastapi_server -from exo.master.election_callback import ElectionCallbacks -from exo.master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor -from exo.master.placement 
import get_instance_placements, get_transition_events +from exo.master.placement import ( + get_instance_placements_after_create, + get_instance_placements_after_delete, + get_transition_events, +) from exo.shared.apply import apply -from exo.shared.constants import EXO_MASTER_LOG -from exo.shared.db.sqlite.config import EventLogConfig -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage -from exo.shared.db.sqlite.event_log_manager import EventLogManager -from exo.shared.keypair import Keypair, get_node_id_keypair -from exo.shared.logging import logger_cleanup, logger_setup +from exo.shared.types.commands import ( + ChatCompletion, + CreateInstance, + DeleteInstance, + ForwarderCommand, + RequestEventLog, + SpinUpInstance, + TaskFinished, +) from exo.shared.types.common import CommandId, NodeId from exo.shared.types.events import ( Event, - Heartbeat, + ForwarderEvent, + IndexedEvent, InstanceDeleted, + TaggedEvent, TaskCreated, TaskDeleted, TopologyEdgeDeleted, - TopologyNodeCreated, -) -from exo.shared.types.events.commands import ( - ChatCompletionCommand, - Command, - CreateInstanceCommand, - DeleteInstanceCommand, - TaskFinishedCommand, ) from exo.shared.types.state import State from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType -from exo.shared.types.worker.instances import Instance +from exo.shared.types.worker.common import InstanceId +from exo.utils.channels import Receiver, Sender, channel +from exo.utils.event_buffer import MultiSourceBuffer class Master: def __init__( self, - node_id_keypair: Keypair, node_id: NodeId, - command_buffer: list[Command], - global_events: AsyncSQLiteEventStorage, - worker_events: AsyncSQLiteEventStorage, - forwarder_binary_path: Path, + *, + command_receiver: Receiver[ForwarderCommand], + # Receiving indexed events from the forwarder to be applied to state + # Ideally these would be WorkerForwarderEvents but type system says no :( + local_event_receiver: Receiver[ForwarderEvent], + # Send events to the forwarder to be indexed (usually from command processing) + # Ideally these would be MasterForwarderEvents but type system says no :( + global_event_sender: Sender[ForwarderEvent], + tb_only: bool = False, ): self.state = State() - self.node_id_keypair = node_id_keypair + self._tg: TaskGroup | None = None self.node_id = node_id - self.command_buffer = command_buffer - self.global_events = global_events - self.worker_events = worker_events self.command_task_mapping: dict[CommandId, TaskId] = {} - self.forwarder_supervisor = ForwarderSupervisor( - self.node_id, - forwarder_binary_path=forwarder_binary_path, + self.command_receiver = command_receiver + self.local_event_receiver = local_event_receiver + self.global_event_sender = global_event_sender + send, recv = channel[Event]() + self.event_sender: Sender[Event] = send + self._loopback_event_receiver: Receiver[Event] = recv + self._loopback_event_sender: Sender[ForwarderEvent] = ( + local_event_receiver.clone_sender() ) - self.election_callbacks = ElectionCallbacks(self.forwarder_supervisor) - - @property - def event_log_for_reads(self) -> AsyncSQLiteEventStorage: - return self.global_events - - @property - def event_log_for_writes(self) -> AsyncSQLiteEventStorage: - if self.forwarder_supervisor.current_role == ForwarderRole.MASTER: - return self.global_events - else: - return self.worker_events - - async def _get_state_snapshot(self) -> State: - # TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we 
don't have to re-apply all events - return State() - - async def _run_event_loop_body(self) -> None: - next_events: list[Event] = [] - # 1. process commands - if ( - self.forwarder_supervisor.current_role == ForwarderRole.MASTER - and len(self.command_buffer) > 0 - ): - # for now we do one command at a time - next_command = self.command_buffer.pop(0) - - logger.bind(user_facing=True).info(f"Executing command: {next_command}") - logger.info(f"Got command: {next_command}") - - # TODO: validate the command - match next_command: - case ChatCompletionCommand(): - matching_instance: Instance | None = None - for instance in self.state.instances.values(): - if ( - instance.shard_assignments.model_id - == next_command.request_params.model - ): - matching_instance = instance - break - if not matching_instance: - raise ValueError( - f"No instance found for model {next_command.request_params.model}" - ) - - task_id = TaskId() - next_events.append( - TaskCreated( - task_id=task_id, - task=ChatCompletionTask( - task_type=TaskType.CHAT_COMPLETION, - task_id=task_id, - command_id=next_command.command_id, - instance_id=matching_instance.instance_id, - task_status=TaskStatus.PENDING, - task_params=next_command.request_params, - ), - ) - ) - - self.command_task_mapping[next_command.command_id] = task_id - case DeleteInstanceCommand(): - placement = get_instance_placements( - next_command, self.state.topology, self.state.instances - ) - transition_events = get_transition_events( - self.state.instances, placement - ) - next_events.extend(transition_events) - case CreateInstanceCommand(): - placement = get_instance_placements( - next_command, self.state.topology, self.state.instances - ) - transition_events = get_transition_events( - self.state.instances, placement - ) - next_events.extend(transition_events) - case TaskFinishedCommand(): - next_events.append( - TaskDeleted( - task_id=self.command_task_mapping[next_command.command_id] - ) - ) - del self.command_task_mapping[next_command.command_id] - - await self.event_log_for_writes.append_events( - next_events, origin=self.node_id - ) - # 2. get latest events - events = await self.event_log_for_reads.get_events_since( - self.state.last_event_applied_idx, ignore_no_op_events=True - ) - if len(events) == 0: - await asyncio.sleep(0.01) - return - - if len(events) == 1: - logger.debug(f"Master received event: {events[0]}") - else: - logger.debug(f"Master received events: {events}") - - # 3. for each event, apply it to the state - for event_from_log in events: - logger.trace(f"Applying event: {event_from_log}") - self.state = apply(self.state, event_from_log) - logger.trace(f"State: {self.state.model_dump_json()}") - - # TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken. 
- write_events: list[Event] = [] - if any( - [ - isinstance(event_from_log.event, TopologyEdgeDeleted) - for event_from_log in events - ] - ): - connected_node_ids = set( - [x.node_id for x in self.state.topology.list_nodes()] - ) - for instance_id, instance in self.state.instances.items(): - delete = False - for node_id in instance.shard_assignments.node_to_runner: - if node_id not in connected_node_ids: - delete = True - break - if delete: - write_events.append(InstanceDeleted(instance_id=instance_id)) - - if write_events: - await self.event_log_for_writes.append_events( - events=write_events, origin=self.node_id - ) + self._multi_buffer = MultiSourceBuffer[NodeId, Event]() + # TODO: not have this + self._event_log: list[Event] = [] + self.tb_only = tb_only async def run(self): - self.state = await self._get_state_snapshot() + logger.info("Starting Master") - async def heartbeat_task(): - while True: - await self.event_log_for_writes.append_events( - [Heartbeat(node_id=self.node_id)], origin=self.node_id + async with create_task_group() as tg: + self._tg = tg + tg.start_soon(self._event_processor) + tg.start_soon(self._command_processor) + tg.start_soon(self._loopback_processor) + self.global_event_sender.close() + self.local_event_receiver.close() + self.command_receiver.close() + self._loopback_event_sender.close() + self._loopback_event_receiver.close() + + async def shutdown(self): + if self._tg: + logger.info("Stopping Master") + self._tg.cancel_scope.cancel() + + async def _command_processor(self) -> None: + with self.command_receiver as commands: + async for forwarder_command in commands: + try: + logger.info( + f"Executing command: {forwarder_command.tagged_command.c}" + ) + generated_events: list[Event] = [] + command = forwarder_command.tagged_command.c + match command: + case ChatCompletion(): + instance_task_counts: dict[InstanceId, int] = {} + for instance in self.state.instances.values(): + if ( + instance.shard_assignments.model_id + == command.request_params.model + ): + task_count = sum( + 1 + for task in self.state.tasks.values() + if task.instance_id == instance.instance_id + ) + instance_task_counts[instance.instance_id] = ( + task_count + ) + + if not instance_task_counts: + logger.warning( + f"No instance found for model {command.request_params.model}" + ) + continue + + available_instance_ids = sorted( + instance_task_counts.keys(), + key=lambda instance_id: instance_task_counts[ + instance_id + ], + ) + + task_id = TaskId() + generated_events.append( + TaskCreated( + task_id=task_id, + task=ChatCompletionTask( + task_type=TaskType.CHAT_COMPLETION, + task_id=task_id, + command_id=command.command_id, + instance_id=available_instance_ids[0], + task_status=TaskStatus.PENDING, + task_params=command.request_params, + ), + ) + ) + + self.command_task_mapping[command.command_id] = task_id + case DeleteInstance(): + placement = get_instance_placements_after_delete( + command, self.state.instances + ) + transition_events = get_transition_events( + self.state.instances, placement + ) + generated_events.extend(transition_events) + case CreateInstance(): + placement = get_instance_placements_after_create( + command, + self.state.topology, + self.state.instances, + tb_only=self.tb_only, + ) + transition_events = get_transition_events( + self.state.instances, placement + ) + generated_events.extend(transition_events) + case TaskFinished(): + generated_events.append( + TaskDeleted( + task_id=self.command_task_mapping[ + command.finished_command_id + ] + ) + ) + if 
command.finished_command_id in self.command_task_mapping: + del self.command_task_mapping[ + command.finished_command_id + ] + case SpinUpInstance(): + raise NotImplementedError + case RequestEventLog(): + # We should just be able to send everything, since other buffers will ignore old messages + for i in range(command.since_idx, len(self._event_log)): + await self._send_event( + IndexedEvent(idx=i, event=self._event_log[i]) + ) + for event in generated_events: + await self.event_sender.send(event) + except Exception as e: + logger.opt(exception=e).warning("Error in command processor") + + async def _event_processor(self) -> None: + with self.local_event_receiver as local_events: + async for local_event in local_events: + self._multi_buffer.ingest( + local_event.origin_idx, + local_event.tagged_event.c, + local_event.origin, ) - await asyncio.sleep(5) + for event in self._multi_buffer.drain(): + logger.debug(f"Master indexing event: {str(event)[:100]}") + indexed = IndexedEvent(event=event, idx=len(self._event_log)) + self.state = apply(self.state, indexed) + # TODO: SQL + self._event_log.append(event) + await self._send_event(indexed) - asyncio.create_task(heartbeat_task()) + # TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken. + if isinstance(event, TopologyEdgeDeleted): + connected_node_ids = set( + [x.node_id for x in self.state.topology.list_nodes()] + ) + for instance_id, instance in self.state.instances.items(): + for node_id in instance.shard_assignments.node_to_runner: + if node_id not in connected_node_ids: + await self.event_sender.send( + InstanceDeleted(instance_id=instance_id) + ) + break - # TODO: we should clean these up on shutdown - await self.forwarder_supervisor.start_as_replica() - if os.getenv("EXO_RUN_AS_REPLICA") in set(["TRUE", "true", "1"]): - await self.election_callbacks.on_became_replica() - else: - await self.election_callbacks.on_became_master() + async def _loopback_processor(self) -> None: + # this would ideally not be necessary. 
+ # this is WAY less hacky than how I was working around this before + local_index = 0 + with self._loopback_event_receiver as events: + async for event in events: + await self._loopback_event_sender.send( + ForwarderEvent( + origin=NodeId(f"master_{self.node_id}"), + origin_idx=local_index, + tagged_event=TaggedEvent.from_(event), + ) + ) + local_index += 1 - role = ( - "MASTER" - if self.forwarder_supervisor.current_role == ForwarderRole.MASTER - else "REPLICA" + async def _send_event(self, event: IndexedEvent): + # Convenience method since this line is ugly + await self.global_event_sender.send( + ForwarderEvent( + origin=self.node_id, + origin_idx=event.idx, + tagged_event=TaggedEvent.from_(event.event), + ) ) - await self.event_log_for_writes.append_events( - [TopologyNodeCreated(node_id=self.node_id, role=role)], origin=self.node_id - ) - while True: - try: - await self._run_event_loop_body() - except Exception as e: - logger.opt(exception=e).error(f"Error in _run_event_loop_body: {e}") - await asyncio.sleep(0.1) - - -async def async_main(): - node_id_keypair = get_node_id_keypair() - node_id = NodeId(node_id_keypair.to_peer_id().to_base58()) - - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - global_events: AsyncSQLiteEventStorage = event_log_manager.global_events - worker_events: AsyncSQLiteEventStorage = event_log_manager.worker_events - - command_buffer: list[Command] = [] - - logger.info("Starting EXO Master") - logger.info(f"Starting Master with node_id: {node_id}") - - api_port = int(os.environ.get("API_PORT", 8000)) - - api_thread = threading.Thread( - target=start_fastapi_server, - args=(command_buffer, global_events, lambda: master.state, "0.0.0.0", api_port), - daemon=True, - ) - api_thread.start() - logger.bind(user_facing=True).info(f"Dashboard started on port {api_port}.") - - master = Master( - node_id_keypair, - node_id, - command_buffer, - global_events, - worker_events, - Path(os.environ["GO_BUILD_DIR"]) / "forwarder", - ) - await master.run() - logger_cleanup() # pyright: ignore[reportUnreachable] - - -def main(logfile: Path = EXO_MASTER_LOG, verbosity: int = 1): - logger_setup(logfile, verbosity) - asyncio.run(async_main()) - - -if __name__ == "__main__": - main() diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index f61da749..e3884d53 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -1,22 +1,22 @@ import random from collections.abc import Mapping from copy import deepcopy -from functools import singledispatch from typing import Sequence -from exo.master.utils.placement_utils import ( +from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, get_shard_assignments, get_smallest_cycles, ) from exo.shared.topology import Topology +from exo.shared.types.commands import ( + CreateInstance, + DeleteInstance, +) from exo.shared.types.common import Host from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted -from exo.shared.types.events.commands import ( - CreateInstanceCommand, - DeleteInstanceCommand, -) +from exo.shared.types.memory import Memory from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.instances import Instance, InstanceStatus @@ -25,26 +25,24 @@ def random_ephemeral_port() -> int: return random.randint(49152, 65535) -@singledispatch -def get_instance_placements( - command: CreateInstanceCommand, +def get_instance_placements_after_create( + command: CreateInstance, topology: 
Topology, - current_instances: dict[InstanceId, Instance], + current_instances: Mapping[InstanceId, Instance], + *, + tb_only: bool = False, ) -> dict[InstanceId, Instance]: - available_models = [ - current_instances[instance].shard_assignments.model_id - for instance in current_instances - ] - if command.model_meta.model_id in available_models: - raise ValueError(f"Instance for {command.model_meta.model_id} already exists") - all_nodes = list(topology.list_nodes()) - cycles = topology.get_cycles() + from loguru import logger + + logger.info("finding cycles:") + cycles = topology.get_cycles_tb() + logger.info(f"{cycles=}") # we can also always just have a node on its own singleton_cycles = [[node] for node in all_nodes] candidate_cycles = cycles + singleton_cycles cycles_with_sufficient_memory = filter_cycles_by_memory( - candidate_cycles, command.model_meta.storage_size_kilobytes * 1024 + candidate_cycles, command.model_meta.storage_size ) if not cycles_with_sufficient_memory: raise ValueError("No cycles found with sufficient memory") @@ -52,25 +50,27 @@ def get_instance_placements( smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory) selected_cycle = None - has_thunderbolt_cycle = any( - [ - topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle) - for cycle in smallest_cycles - ] - ) - if has_thunderbolt_cycle: - smallest_cycles = [ - cycle - for cycle in smallest_cycles - if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle) - ] + smallest_tb_cycles = [ + cycle + for cycle in smallest_cycles + if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle) + ] + + if tb_only and smallest_tb_cycles == []: + raise ValueError("No cycles found with sufficient memory") + + elif smallest_tb_cycles != []: + smallest_cycles = smallest_tb_cycles selected_cycle = max( smallest_cycles, key=lambda cycle: sum( - node.node_profile.memory.ram_available - for node in cycle - if node.node_profile is not None + ( + node.node_profile.memory.ram_available + for node in cycle + if node.node_profile is not None + ), + start=Memory(), ), ) @@ -79,8 +79,8 @@ def get_instance_placements( cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle) hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph) - instance_id = command.instance_id - target_instances = deepcopy(current_instances) + instance_id = InstanceId() + target_instances = dict(deepcopy(current_instances)) target_instances[instance_id] = Instance( instance_id=instance_id, instance_type=InstanceStatus.ACTIVE, @@ -88,6 +88,9 @@ def get_instance_placements( hosts=[ Host( ip=host.ip, + # NOTE: this is stupid + # | + # v # NOTE: it's fine to have non-deterministic ports here since this is in a command decision port=random_ephemeral_port(), ) @@ -97,13 +100,11 @@ def get_instance_placements( return target_instances -@get_instance_placements.register -def _( - command: DeleteInstanceCommand, - topology: Topology, - current_instances: dict[InstanceId, Instance], +def get_instance_placements_after_delete( + command: DeleteInstance, + current_instances: Mapping[InstanceId, Instance], ) -> dict[InstanceId, Instance]: - target_instances = deepcopy(current_instances) + target_instances = dict(deepcopy(current_instances)) if command.instance_id in target_instances: del target_instances[command.instance_id] return target_instances diff --git a/src/exo/master/utils/placement_utils.py b/src/exo/master/placement_utils.py similarity index 77% rename from src/exo/master/utils/placement_utils.py 
rename to src/exo/master/placement_utils.py index b89736b1..16be2a0c 100644 --- a/src/exo/master/utils/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -4,9 +4,10 @@ from pydantic import BaseModel from exo.shared.topology import Topology from exo.shared.types.common import Host, NodeId +from exo.shared.types.memory import Memory from exo.shared.types.models import ModelMetadata from exo.shared.types.profiling import NodePerformanceProfile -from exo.shared.types.topology import Node +from exo.shared.types.topology import NodeInfo from exo.shared.types.worker.common import RunnerId from exo.shared.types.worker.runners import ShardAssignments from exo.shared.types.worker.shards import PipelineShardMetadata @@ -17,38 +18,41 @@ class NodeWithProfile(BaseModel): node_profile: NodePerformanceProfile -def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]: +def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]: return all(node.node_profile is not None for node in nodes) def filter_cycles_by_memory( - cycles: list[list[Node]], required_memory: int -) -> list[list[Node]]: - filtered_cycles: list[list[Node]] = [] + cycles: list[list[NodeInfo]], required_memory: Memory +) -> list[list[NodeInfo]]: + filtered_cycles: list[list[NodeInfo]] = [] for cycle in cycles: if not narrow_all_nodes(cycle): continue - total_mem = sum(node.node_profile.memory.ram_available for node in cycle) + total_mem = sum( + (node.node_profile.memory.ram_available for node in cycle), start=Memory() + ) if total_mem >= required_memory: - filtered_cycles.append(cast(list[Node], cycle)) + filtered_cycles.append(cast(list[NodeInfo], cycle)) return filtered_cycles -def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]: +def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]: min_nodes = min(len(cycle) for cycle in cycles) return [cycle for cycle in cycles if len(cycle) == min_nodes] def get_shard_assignments( model_meta: ModelMetadata, - selected_cycle: list[Node], + selected_cycle: list[NodeInfo], ) -> ShardAssignments: if not narrow_all_nodes(selected_cycle): raise ValueError("All nodes must have profiles to create shard assignments") cycle_memory = sum( - node.node_profile.memory.ram_available for node in selected_cycle + (node.node_profile.memory.ram_available for node in selected_cycle), + start=Memory(), ) total_layers = model_meta.n_layers runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {} @@ -60,7 +64,11 @@ def get_shard_assignments( node_layers = total_layers - layers_assigned else: node_layers = round( - total_layers * (node.node_profile.memory.ram_available / cycle_memory) + total_layers + * ( + node.node_profile.memory.ram_available.in_bytes + / cycle_memory.in_bytes + ) ) node_layers = max(1, node_layers) @@ -109,6 +117,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]: ): if get_thunderbolt and not connection.is_thunderbolt(): continue + assert connection.send_back_multiaddr is not None host = Host( ip=connection.send_back_multiaddr.ip_address, port=connection.send_back_multiaddr.port, diff --git a/src/exo/master/tests/conftest.py b/src/exo/master/tests/conftest.py index fcfaace4..a22333b9 100644 --- a/src/exo/master/tests/conftest.py +++ b/src/exo/master/tests/conftest.py @@ -1,3 +1,5 @@ +from typing import Callable + import pytest from exo.shared.types.common import NodeId @@ -7,21 +9,21 @@ from exo.shared.types.profiling import ( NodePerformanceProfile, SystemPerformanceProfile, ) -from 
exo.shared.types.topology import Connection, ConnectionProfile, Node +from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo @pytest.fixture def create_node(): - def _create_node(memory: int, node_id: NodeId | None = None) -> Node: + def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo: if node_id is None: node_id = NodeId() - return Node( + return NodeInfo( node_id=node_id, node_profile=NodePerformanceProfile( model_id="test", chip_id="test", friendly_name="test", - memory=MemoryPerformanceProfile( + memory=MemoryPerformanceProfile.from_bytes( ram_total=1000, ram_available=memory, swap_total=1000, @@ -37,7 +39,7 @@ def create_node(): # TODO: this is a hack to get the port for the send_back_multiaddr @pytest.fixture -def create_connection(): +def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]: port_counter = 1235 def _create_connection( @@ -50,7 +52,6 @@ def create_connection(): return Connection( local_node_id=source_node_id, send_back_node_id=sink_node_id, - local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"), send_back_multiaddr=Multiaddr( address=f"/ip4/127.0.0.1/tcp/{send_back_port}" ), diff --git a/src/exo/master/tests/test_forwarder_supervisor.py b/src/exo/master/tests/test_forwarder_supervisor.py index dabdf5cb..97cb6ec6 100644 --- a/src/exo/master/tests/test_forwarder_supervisor.py +++ b/src/exo/master/tests/test_forwarder_supervisor.py @@ -23,9 +23,8 @@ from exo.shared.constants import ( EXO_GLOBAL_EVENT_DB, EXO_WORKER_EVENT_DB, LIBP2P_GLOBAL_EVENTS_TOPIC, - LIBP2P_WORKER_EVENTS_TOPIC, + LIBP2P_LOCAL_EVENTS_TOPIC, ) -from exo.shared.logging import logger_test_install from exo.shared.types.common import NodeId # Mock forwarder script content @@ -192,7 +191,6 @@ class TestForwardersupervisorBasic: ], ) -> None: """Test starting forwarder in replica mode.""" - logger_test_install(test_logger) # Set environment os.environ.update(mock_env_vars) @@ -216,7 +214,7 @@ class TestForwardersupervisorBasic: # Expected replica forwarding pairs expected_pairs = [ - f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}", + f"sqlite:{EXO_WORKER_EVENT_DB}:events|libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}", f"libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events", ] @@ -238,7 +236,6 @@ class TestForwardersupervisorBasic: ], ) -> None: """Test changing role from replica to master.""" - logger_test_install(test_logger) os.environ.update(mock_env_vars) supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script) @@ -265,7 +262,7 @@ class TestForwardersupervisorBasic: # Expected master forwarding pairs master_pairs = [ - f"libp2p:{LIBP2P_WORKER_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events", + f"libp2p:{LIBP2P_LOCAL_EVENTS_TOPIC}|sqlite:{EXO_GLOBAL_EVENT_DB}:events", f"sqlite:{EXO_GLOBAL_EVENT_DB}:events|libp2p:{LIBP2P_GLOBAL_EVENTS_TOPIC}", ] @@ -285,7 +282,6 @@ class TestForwardersupervisorBasic: ], ) -> None: """Test that setting the same role twice doesn't restart the process.""" - logger_test_install(test_logger) os.environ.update(mock_env_vars) supervisor = ForwarderSupervisor(NodeId(), mock_forwarder_script) @@ -316,7 +312,6 @@ class TestForwardersupervisorBasic: ], ) -> None: """Test that Forwardersupervisor restarts the process if it crashes.""" - logger_test_install(test_logger) # Configure mock to exit after 1 second mock_env_vars["MOCK_EXIT_AFTER"] = "1" mock_env_vars["MOCK_EXIT_CODE"] = "1" @@ -365,7 +360,6 @@ class TestForwardersupervisorBasic: self, test_logger: 
logging.Logger, temp_dir: Path ) -> None: """Test behavior when forwarder binary doesn't exist.""" - logger_test_install(test_logger) nonexistent_path = temp_dir / "nonexistent_forwarder" supervisor = ForwarderSupervisor(NodeId(), nonexistent_path) @@ -381,7 +375,6 @@ class TestElectionCallbacks: @pytest.mark.asyncio async def test_on_became_master(self, test_logger: logging.Logger) -> None: """Test callback when becoming master.""" - logger_test_install(test_logger) mock_supervisor = MagicMock(spec=ForwarderSupervisor) mock_supervisor.notify_role_change = AsyncMock() @@ -393,7 +386,6 @@ class TestElectionCallbacks: @pytest.mark.asyncio async def test_on_became_replica(self, test_logger: logging.Logger) -> None: """Test callback when becoming replica.""" - logger_test_install(test_logger) mock_supervisor = MagicMock(spec=ForwarderSupervisor) mock_supervisor.notify_role_change = AsyncMock() diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index cc0c02ad..b93f2bb7 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -1,30 +1,29 @@ import asyncio import tempfile -from logging import Logger from pathlib import Path from typing import List, Sequence import pytest from exo.master.main import Master -from exo.shared.db.sqlite.config import EventLogConfig -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage -from exo.shared.db.sqlite.event_log_manager import EventLogManager +from exo.shared.db.config import EventLogConfig +from exo.shared.db.connector import AsyncSQLiteEventStorage +from exo.shared.db.event_log_manager import EventLogManager from exo.shared.keypair import Keypair -from exo.shared.logging import logger_test_install from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from exo.shared.types.common import NodeId -from exo.shared.types.events import Event, EventFromEventLog, Heartbeat, TaskCreated -from exo.shared.types.events._events import ( - InstanceCreated, - NodePerformanceMeasured, - TopologyNodeCreated, -) -from exo.shared.types.events.commands import ( - ChatCompletionCommand, +from exo.shared.types.commands import ( + ChatCompletion, Command, CommandId, - CreateInstanceCommand, + CreateInstance, +) +from exo.shared.types.common import NodeId +from exo.shared.types.events import ( + IndexedEvent, + InstanceCreated, + NodePerformanceMeasured, + TaskCreated, + TopologyNodeCreated, ) from exo.shared.types.models import ModelMetadata from exo.shared.types.profiling import ( @@ -35,7 +34,6 @@ from exo.shared.types.profiling import ( from exo.shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType from exo.shared.types.worker.instances import ( Instance, - InstanceId, InstanceStatus, ShardAssignments, ) @@ -43,7 +41,7 @@ from exo.shared.types.worker.shards import PartitionStrategy, PipelineShardMetad def _create_forwarder_dummy_binary() -> Path: - path = Path(tempfile.mktemp()) / "forwarder.bin" + path = Path(tempfile.mkstemp()[1]) / "forwarder.bin" if not path.exists(): path.parent.mkdir(parents=True, exist_ok=True) path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n") @@ -53,23 +51,20 @@ def _create_forwarder_dummy_binary() -> Path: @pytest.mark.asyncio async def test_master(): - logger = Logger(name="test_master_logger") - logger_test_install(logger) event_log_manager = EventLogManager(EventLogConfig()) await event_log_manager.initialize() global_events: AsyncSQLiteEventStorage = event_log_manager.global_events await 
global_events.delete_all_events() - async def _get_events() -> Sequence[EventFromEventLog[Event]]: + async def _get_events() -> Sequence[IndexedEvent]: orig_events = await global_events.get_events_since(0) override_idx_in_log = 1 - events: List[EventFromEventLog[Event]] = [] + events: List[IndexedEvent] = [] for e in orig_events: - if isinstance(e.event, Heartbeat): - continue events.append( - EventFromEventLog( - event=e.event, origin=e.origin, idx_in_log=override_idx_in_log + IndexedEvent( + event=e.event, + idx=override_idx_in_log, # origin=e.origin, ) ) override_idx_in_log += 1 @@ -120,9 +115,8 @@ async def test_master(): await asyncio.sleep(0.001) command_buffer.append( - CreateInstanceCommand( + CreateInstance( command_id=CommandId(), - instance_id=InstanceId(), model_meta=ModelMetadata( model_id="llama-3.2-1b", pretty_name="Llama 3.2 1B", @@ -134,7 +128,7 @@ async def test_master(): while len(master.state.instances.keys()) == 0: await asyncio.sleep(0.001) command_buffer.append( - ChatCompletionCommand( + ChatCompletion( command_id=CommandId(), request_params=ChatCompletionTaskParams( model="llama-3.2-1b", @@ -150,7 +144,7 @@ async def test_master(): events = await _get_events() print(events) assert len(events) == 4 - assert events[0].idx_in_log == 1 + assert events[0].idx == 1 assert isinstance(events[0].event, TopologyNodeCreated) assert isinstance(events[1].event, NodePerformanceMeasured) assert isinstance(events[2].event, InstanceCreated) diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 4f83fcfa..16a33200 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -2,15 +2,17 @@ from typing import Callable import pytest -from exo.master.placement import get_instance_placements, get_transition_events -from exo.shared.topology import Topology -from exo.shared.types.common import CommandId, NodeId -from exo.shared.types.events._events import ( - _EventType, # pyright: ignore[reportPrivateUsage] +from exo.master.placement import ( + get_instance_placements_after_create, + get_transition_events, ) -from exo.shared.types.events.commands import CreateInstanceCommand -from exo.shared.types.models import ModelMetadata -from exo.shared.types.topology import Connection, Node +from exo.shared.topology import Topology +from exo.shared.types.commands import CreateInstance +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.events import InstanceCreated, InstanceDeleted +from exo.shared.types.memory import Memory +from exo.shared.types.models import ModelId, ModelMetadata +from exo.shared.types.topology import Connection, NodeInfo from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.instances import Instance, InstanceStatus from exo.shared.types.worker.runners import ShardAssignments @@ -27,7 +29,7 @@ def instance() -> Instance: instance_id=InstanceId(), instance_type=InstanceStatus.ACTIVE, shard_assignments=ShardAssignments( - model_id="test-model", runner_to_shard={}, node_to_runner={} + model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={} ), hosts=[], ) @@ -36,18 +38,17 @@ def instance() -> Instance: @pytest.fixture def model_meta() -> ModelMetadata: return ModelMetadata( - model_id="test-model", - storage_size_kilobytes=1000, + model_id=ModelId("test-model"), + storage_size=Memory.from_kb(1000), pretty_name="Test Model", n_layers=10, ) -def create_instance_command(model_meta: ModelMetadata) -> CreateInstanceCommand: - return 
CreateInstanceCommand( +def create_instance_command(model_meta: ModelMetadata) -> CreateInstance: + return CreateInstance( command_id=CommandId(), model_meta=model_meta, - instance_id=InstanceId(), ) @@ -65,32 +66,33 @@ def test_get_instance_placements_create_instance( expected_layers: tuple[int, int, int], topology: Topology, model_meta: ModelMetadata, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[Memory, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], ): # arrange model_meta.n_layers = total_layers - model_meta.storage_size_kilobytes = sum( + model_meta.storage_size.in_bytes = sum( available_memory ) # make it exactly fit across all nodes - create_instance_command = CreateInstanceCommand( + create_instance_command = CreateInstance( command_id=CommandId(), model_meta=model_meta, - instance_id=InstanceId(), ) node_id_a = NodeId() node_id_b = NodeId() node_id_c = NodeId() - topology.add_node(create_node(available_memory[0] * 1024, node_id_a)) - topology.add_node(create_node(available_memory[1] * 1024, node_id_b)) - topology.add_node(create_node(available_memory[2] * 1024, node_id_c)) + topology.add_node(create_node(Memory.from_bytes(available_memory[0]), node_id_a)) + topology.add_node(create_node(Memory.from_bytes(available_memory[1]), node_id_b)) + topology.add_node(create_node(Memory.from_bytes(available_memory[2]), node_id_c)) topology.add_connection(create_connection(node_id_a, node_id_b)) topology.add_connection(create_connection(node_id_b, node_id_c)) topology.add_connection(create_connection(node_id_c, node_id_a)) # act - placements = get_instance_placements(create_instance_command, topology, {}) + placements = get_instance_placements_after_create( + create_instance_command, topology, {} + ) # assert assert len(placements) == 1 @@ -117,22 +119,23 @@ def test_get_instance_placements_create_instance( def test_get_instance_placements_one_node_exact_fit( - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], ) -> None: topology = Topology() node_id = NodeId() topology.add_node(create_node(1000 * 1024, node_id)) - create_instance_command = CreateInstanceCommand( + create_instance_command = CreateInstance( command_id=CommandId(), model_meta=ModelMetadata( - model_id="test-model", - storage_size_kilobytes=1000, + model_id=ModelId("test-model"), + storage_size=Memory.from_kb(1000), pretty_name="Test Model", n_layers=10, ), - instance_id=InstanceId(), ) - placements = get_instance_placements(create_instance_command, topology, {}) + placements = get_instance_placements_after_create( + create_instance_command, topology, {} + ) assert len(placements) == 1 instance_id = list(placements.keys())[0] @@ -144,22 +147,23 @@ def test_get_instance_placements_one_node_exact_fit( def test_get_instance_placements_one_node_fits_with_extra_memory( - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], ) -> None: topology = Topology() node_id = NodeId() topology.add_node(create_node(1001 * 1024, node_id)) - create_instance_command = CreateInstanceCommand( + create_instance_command = CreateInstance( command_id=CommandId(), model_meta=ModelMetadata( - model_id="test-model", - storage_size_kilobytes=1000, + model_id=ModelId("test-model"), + storage_size=Memory.from_kb(1000), pretty_name="Test Model", n_layers=10, ), - instance_id=InstanceId(), ) - placements = get_instance_placements(create_instance_command, topology, {}) + placements 
= get_instance_placements_after_create( + create_instance_command, topology, {} + ) assert len(placements) == 1 instance_id = list(placements.keys())[0] @@ -171,27 +175,26 @@ def test_get_instance_placements_one_node_fits_with_extra_memory( def test_get_instance_placements_one_node_not_fit( - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], ) -> None: topology = Topology() node_id = NodeId() topology.add_node(create_node(1000 * 1024, node_id)) - create_instance_command = CreateInstanceCommand( + create_instance_command = CreateInstance( command_id=CommandId(), model_meta=ModelMetadata( - model_id="test-model", - storage_size_kilobytes=1001, + model_id=ModelId("test-model"), + storage_size=Memory.from_kb(1001), pretty_name="Test Model", n_layers=10, ), - instance_id=InstanceId(), ) with pytest.raises(ValueError, match="No cycles found with sufficient memory"): - get_instance_placements(create_instance_command, topology, {}) + get_instance_placements_after_create(create_instance_command, topology, {}) -def test_get_transition_events_no_change(topology: Topology, instance: Instance): +def test_get_transition_events_no_change(instance: Instance): # arrange instance_id = InstanceId() current_instances = {instance_id: instance} @@ -204,7 +207,7 @@ def test_get_transition_events_no_change(topology: Topology, instance: Instance) assert len(events) == 0 -def test_get_transition_events_create_instance(topology: Topology, instance: Instance): +def test_get_transition_events_create_instance(instance: Instance): # arrange instance_id = InstanceId() current_instances: dict[InstanceId, Instance] = {} @@ -215,10 +218,10 @@ def test_get_transition_events_create_instance(topology: Topology, instance: Ins # assert assert len(events) == 1 - assert events[0].event_type == _EventType.InstanceCreated + assert isinstance(events[0], InstanceCreated) -def test_get_transition_events_delete_instance(topology: Topology, instance: Instance): +def test_get_transition_events_delete_instance(instance: Instance): # arrange instance_id = InstanceId() current_instances: dict[InstanceId, Instance] = {instance_id: instance} @@ -229,5 +232,5 @@ def test_get_transition_events_delete_instance(topology: Topology, instance: Ins # assert assert len(events) == 1 - assert events[0].event_type == _EventType.InstanceDeleted + assert isinstance(events[0], InstanceDeleted) assert events[0].instance_id == instance_id diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index ed1dadc2..31796a36 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -1,9 +1,8 @@ -from ipaddress import IPv4Address from typing import Callable import pytest -from exo.master.utils.placement_utils import ( +from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, get_shard_assignments, @@ -11,8 +10,9 @@ from exo.master.utils.placement_utils import ( ) from exo.shared.topology import Topology from exo.shared.types.common import Host, NodeId -from exo.shared.types.models import ModelMetadata -from exo.shared.types.topology import Connection, Node +from exo.shared.types.memory import Memory +from exo.shared.types.models import ModelId, ModelMetadata +from exo.shared.types.topology import Connection, NodeInfo @pytest.fixture @@ -23,7 +23,7 @@ def topology() -> Topology: def test_filter_cycles_by_memory( topology: Topology, - create_node: Callable[[int, NodeId | None], 
Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], ): # arrange @@ -47,7 +47,7 @@ def test_filter_cycles_by_memory( assert len(cycles[0]) == 2 # act - filtered_cycles = filter_cycles_by_memory(cycles, 1) + filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1)) # assert assert len(filtered_cycles) == 1 @@ -57,7 +57,7 @@ def test_filter_cycles_by_memory( def test_filter_cycles_by_insufficient_memory( topology: Topology, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], ): # arrange @@ -77,7 +77,9 @@ def test_filter_cycles_by_insufficient_memory( topology.add_connection(connection2) # act - filtered_cycles = filter_cycles_by_memory(topology.get_cycles(), 2001 * 1024) + filtered_cycles = filter_cycles_by_memory( + topology.get_cycles(), Memory.from_kb(2001) + ) # assert assert len(filtered_cycles) == 0 @@ -85,7 +87,7 @@ def test_filter_cycles_by_insufficient_memory( def test_filter_multiple_cycles_by_memory( topology: Topology, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], ): # arrange @@ -110,7 +112,7 @@ def test_filter_multiple_cycles_by_memory( cycles = topology.get_cycles() # act - filtered_cycles = filter_cycles_by_memory(cycles, 1500 * 1024) + filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500)) # assert assert len(filtered_cycles) == 1 @@ -124,7 +126,7 @@ def test_filter_multiple_cycles_by_memory( def test_get_smallest_cycles( topology: Topology, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], ): # arrange @@ -164,7 +166,7 @@ def test_get_smallest_cycles( ) def test_get_shard_assignments( topology: Topology, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId], Connection], available_memory: tuple[int, int, int], total_layers: int, @@ -189,10 +191,10 @@ def test_get_shard_assignments( topology.add_connection(create_connection(node_b_id, node_a_id)) model_meta = ModelMetadata( - model_id="test-model", + model_id=ModelId("test-model"), pretty_name="Test Model", n_layers=total_layers, - storage_size_kilobytes=1000, + storage_size=Memory.from_kb(1000), ) cycles = topology.get_cycles() selected_cycle = cycles[0] @@ -223,7 +225,7 @@ def test_get_shard_assignments( def test_get_hosts_from_subgraph( topology: Topology, - create_node: Callable[[int, NodeId | None], Node], + create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId, int | None], Connection], ): # arrange @@ -250,9 +252,9 @@ def test_get_hosts_from_subgraph( # assert assert len(hosts) == 3 expected_hosts = [ - Host(ip=IPv4Address("127.0.0.1"), port=5001), - Host(ip=IPv4Address("127.0.0.1"), port=5002), - Host(ip=IPv4Address("127.0.0.1"), port=5003), + Host(ip=("127.0.0.1"), port=5001), + Host(ip=("127.0.0.1"), port=5002), + Host(ip=("127.0.0.1"), port=5003), ] for expected_host in expected_hosts: assert expected_host in hosts diff --git a/src/exo/master/tests/test_topology.py b/src/exo/master/tests/test_topology.py index 18cb84a2..e794c445 100644 --- a/src/exo/master/tests/test_topology.py +++ 
b/src/exo/master/tests/test_topology.py @@ -7,7 +7,7 @@ from exo.shared.types.profiling import ( NodePerformanceProfile, SystemPerformanceProfile, ) -from exo.shared.types.topology import Connection, ConnectionProfile, Node, NodeId +from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo @pytest.fixture @@ -20,7 +20,6 @@ def connection() -> Connection: return Connection( local_node_id=NodeId(), send_back_node_id=NodeId(), - local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"), send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"), connection_profile=ConnectionProfile( throughput=1000, latency=1000, jitter=1000 @@ -30,7 +29,7 @@ def connection() -> Connection: @pytest.fixture def node_profile() -> NodePerformanceProfile: - memory_profile = MemoryPerformanceProfile( + memory_profile = MemoryPerformanceProfile.from_bytes( ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 ) system_profile = SystemPerformanceProfile(flops_fp16=1000) @@ -54,7 +53,7 @@ def test_add_node(topology: Topology, node_profile: NodePerformanceProfile): node_id = NodeId() # act - topology.add_node(Node(node_id=node_id, node_profile=node_profile)) + topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile)) # assert data = topology.get_node_profile(node_id) @@ -65,9 +64,11 @@ def test_add_connection( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -82,9 +83,11 @@ def test_update_node_profile( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -92,7 +95,7 @@ def test_update_node_profile( model_id="test", chip_id="test", friendly_name="test", - memory=MemoryPerformanceProfile( + memory=MemoryPerformanceProfile.from_bytes( ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 ), network_interfaces=[], @@ -113,9 +116,11 @@ def test_update_connection_profile( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -125,7 +130,6 @@ def test_update_connection_profile( connection = Connection( local_node_id=connection.local_node_id, send_back_node_id=connection.send_back_node_id, - local_multiaddr=connection.local_multiaddr, send_back_multiaddr=connection.send_back_multiaddr, connection_profile=new_connection_profile, ) @@ -142,9 +146,11 @@ def 
test_remove_connection_still_connected( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -155,64 +161,15 @@ def test_remove_connection_still_connected( assert topology.get_connection_profile(connection) is None -def test_remove_connection_bridge( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection -): - """Create a bridge scenario: master -> node_a -> node_b - and remove the bridge connection (master -> node_a)""" - # arrange - master_id = NodeId() - node_a_id = NodeId() - node_b_id = NodeId() - - topology.add_node(Node(node_id=master_id, node_profile=node_profile)) - topology.add_node(Node(node_id=node_a_id, node_profile=node_profile)) - topology.add_node(Node(node_id=node_b_id, node_profile=node_profile)) - - topology.set_master_node_id(master_id) - - connection_master_to_a = Connection( - local_node_id=master_id, - send_back_node_id=node_a_id, - local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1234"), - send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"), - connection_profile=ConnectionProfile( - throughput=1000, latency=1000, jitter=1000 - ), - ) - - connection_a_to_b = Connection( - local_node_id=node_a_id, - send_back_node_id=node_b_id, - local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1236"), - send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1237"), - connection_profile=ConnectionProfile( - throughput=1000, latency=1000, jitter=1000 - ), - ) - - topology.add_connection(connection_master_to_a) - topology.add_connection(connection_a_to_b) - - assert len(list(topology.list_nodes())) == 3 - - topology.remove_connection(connection_master_to_a) - - remaining_nodes = list(topology.list_nodes()) - assert len(remaining_nodes) == 1 - assert remaining_nodes[0].node_id == master_id - - assert topology.get_node_profile(node_a_id) is None - assert topology.get_node_profile(node_b_id) is None - - def test_remove_node_still_connected( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -227,9 +184,11 @@ def test_list_nodes( topology: Topology, node_profile: NodePerformanceProfile, connection: Connection ): # arrange - topology.add_node(Node(node_id=connection.local_node_id, node_profile=node_profile)) topology.add_node( - Node(node_id=connection.send_back_node_id, node_profile=node_profile) + NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + ) + topology.add_node( + NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) ) topology.add_connection(connection) @@ -238,7 +197,7 @@ def test_list_nodes( # assert assert len(nodes) == 2 - assert all(isinstance(node, Node) for node in nodes) + assert all(isinstance(node, NodeInfo) for node in nodes) 
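[Reviewer note] A recurring change in the placement and test files above is replacing raw kilobyte integers with a Memory value type: Memory.from_kb / Memory.from_bytes constructors, an .in_bytes accessor, comparisons, and sums seeded with start=Memory(). A minimal sketch of such a type, assuming only the surface visible in this diff (the real exo.shared.types.memory implementation may differ):

from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class Memory:
    in_bytes: int = 0  # Memory() is zero bytes, a valid `start` for sum()

    @classmethod
    def from_bytes(cls, n: int) -> "Memory":
        return cls(in_bytes=n)

    @classmethod
    def from_kb(cls, n: int) -> "Memory":
        return cls(in_bytes=n * 1024)

    def __add__(self, other: "Memory") -> "Memory":
        return Memory(self.in_bytes + other.in_bytes)


# builtin sum() starts from the int 0, so summing Memory values needs an
# explicit Memory() start value -- the same pattern the placement code uses
# when ranking candidate cycles by total available RAM:
cycles = [[Memory.from_kb(512)], [Memory.from_kb(256), Memory.from_kb(512)]]
best = max(cycles, key=lambda cycle: sum(cycle, start=Memory()))
assert sum(best, start=Memory()) == Memory.from_bytes(768 * 1024)
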
assert {node.node_id for node in nodes} == { connection.local_node_id, connection.send_back_node_id, diff --git a/src/exo/routing/__init__.py b/src/exo/routing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/exo/routing/connection_message.py b/src/exo/routing/connection_message.py new file mode 100644 index 00000000..665483ac --- /dev/null +++ b/src/exo/routing/connection_message.py @@ -0,0 +1,37 @@ +from enum import Enum + +from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType + +from exo.shared.types.common import NodeId +from exo.utils.pydantic_ext import CamelCaseModel + +"""Serialisable types for Connection Updates/Messages""" + + +class ConnectionMessageType(Enum): + Connected = 0 + Disconnected = 1 + + @staticmethod + def from_update_type(update_type: ConnectionUpdateType): + match update_type: + case ConnectionUpdateType.Connected: + return ConnectionMessageType.Connected + case ConnectionUpdateType.Disconnected: + return ConnectionMessageType.Disconnected + + +class ConnectionMessage(CamelCaseModel): + node_id: NodeId + connection_type: ConnectionMessageType + remote_ipv4: str + remote_tcp_port: int + + @classmethod + def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage": + return cls( + node_id=NodeId(update.peer_id.to_base58()), + connection_type=ConnectionMessageType.from_update_type(update.update_type), + remote_ipv4=update.remote_ipv4, + remote_tcp_port=update.remote_tcp_port, + ) diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py new file mode 100644 index 00000000..cf89e75f --- /dev/null +++ b/src/exo/routing/router.py @@ -0,0 +1,242 @@ +from copy import copy +from itertools import count +from math import inf +from os import PathLike +from pathlib import Path +from typing import cast + +from anyio import ( + BrokenResourceError, + ClosedResourceError, + create_task_group, + sleep_forever, +) +from anyio.abc import TaskGroup +from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError +from filelock import FileLock +from loguru import logger + +from exo.shared.constants import EXO_NODE_ID_KEYPAIR +from exo.utils.channels import Receiver, Sender, channel +from exo.utils.pydantic_ext import CamelCaseModel + +from .connection_message import ConnectionMessage +from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic + + +# A significant current limitation of the TopicRouter is that it is not capable +# of preventing feedback, as it does not ask for a system id so cannot tell +# which message is coming/going to which system. 
+# This is currently only relevant for Election +class TopicRouter[T: CamelCaseModel]: + def __init__( + self, + topic: TypedTopic[T], + networking_sender: Sender[tuple[str, bytes]], + max_buffer_size: float = inf, + ): + self.topic: TypedTopic[T] = topic + self.senders: set[Sender[T]] = set() + send, recv = channel[T]() + self.receiver: Receiver[T] = recv + self.temp_sender: Sender[T] | None = send + self.networking_sender: Sender[tuple[str, bytes]] = networking_sender + + async def run(self): + logger.debug(f"Topic Router {self.topic} ready to send") + with self.receiver as items: + async for item in items: + # Check if we should send to network + if ( + len(self.senders) == 0 + and self.topic.publish_policy is PublishPolicy.Minimal + ): + await self._send_out(item) + continue + if self.topic.publish_policy is PublishPolicy.Always: + await self._send_out(item) + # Then publish to all senders + await self.publish(item) + + async def shutdown(self): + logger.debug(f"Shutting down Topic Router {self.topic}") + # Close all the things! + for sender in self.senders: + sender.close() + if self.temp_sender: + self.temp_sender.close() + self.receiver.close() + + async def publish(self, item: T): + """ + Publish item T on this topic to all senders. + NB: this sends to ALL receivers, potentially including receivers held by the object doing the sending. + You should handle your own output if you hold a sender + receiver pair. + """ + to_clear: set[Sender[T]] = set() + for sender in copy(self.senders): + try: + await sender.send(item) + except (ClosedResourceError, BrokenResourceError): + to_clear.add(sender) + self.senders -= to_clear + + async def publish_bytes(self, data: bytes): + await self.publish(self.topic.deserialize(data)) + + async def _send_out(self, item: T): + logger.trace(f"TopicRouter {self.topic.topic} sending {item}") + await self.networking_sender.send( + (str(self.topic.topic), self.topic.serialize(item)) + ) + + +class Router: + @classmethod + def create(cls, identity: Keypair) -> "Router": + return cls(handle=NetworkingHandle(identity)) + + def __init__(self, handle: NetworkingHandle): + self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {} + send, recv = channel[tuple[str, bytes]]() + self.networking_receiver: Receiver[tuple[str, bytes]] = recv + self._net: NetworkingHandle = handle + self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send + self._id_count = count() + self._tg: TaskGroup | None = None + + async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]): + assert self._tg is None, "Attempted to register topic after setup time" + send = self._tmp_networking_sender + if send: + self._tmp_networking_sender = None + else: + send = self.networking_receiver.clone_sender() + router = TopicRouter[T](topic, send) + self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router) + await self._networking_subscribe(str(topic.topic)) + + def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]: + router = self.topic_routers.get(topic.topic, None) + # There's gotta be a way to do this without THIS many asserts + assert router is not None + assert router.topic == topic + send: Sender[T] | None = cast(Sender[T] | None, router.temp_sender) + if send: + router.temp_sender = None + return send + try: + sender = cast(Receiver[T], router.receiver).clone_sender() + except ClosedResourceError: + sender, router.receiver = cast( + tuple[Sender[T], Receiver[CamelCaseModel]], channel[T]() + ) + return sender + + def receiver[T: 
CamelCaseModel](self, topic: TypedTopic[T]) -> Receiver[T]: + router = self.topic_routers.get(topic.topic, None) + # There's gotta be a way to do this without THIS many asserts + + assert router is not None + assert router.topic == topic + assert router.topic.model_type == topic.model_type + + send, recv = channel[T]() + router.senders.add(cast(Sender[CamelCaseModel], send)) + + return recv + + async def run(self): + logger.debug("Starting Router") + async with create_task_group() as tg: + self._tg = tg + for topic in self.topic_routers: + router = self.topic_routers[topic] + tg.start_soon(router.run) + tg.start_soon(self._networking_recv) + tg.start_soon(self._networking_recv_connection_messages) + tg.start_soon(self._networking_publish) + # Router only shuts down if you cancel it. + await sleep_forever() + for topic in self.topic_routers: + await self._networking_unsubscribe(str(topic)) + + async def shutdown(self): + logger.debug("Shutting down Router") + if not self._tg: + return + self._tg.cancel_scope.cancel() + + async def _networking_subscribe(self, topic: str): + logger.info(f"Subscribing to {topic}") + await self._net.gossipsub_subscribe(topic) + + async def _networking_unsubscribe(self, topic: str): + logger.info(f"Unsubscribing from {topic}") + await self._net.gossipsub_unsubscribe(topic) + + async def _networking_recv(self): + while True: + topic, data = await self._net.gossipsub_recv() + logger.trace(f"Received message on {topic} with payload {data}") + if topic not in self.topic_routers: + logger.warning(f"Received message on unknown or inactive topic {topic}") + continue + + router = self.topic_routers[topic] + await router.publish_bytes(data) + + async def _networking_recv_connection_messages(self): + while True: + update = await self._net.connection_update_recv() + message = ConnectionMessage.from_update(update) + logger.trace( + f"Received message on connection_messages with payload {message}" + ) + if CONNECTION_MESSAGES.topic in self.topic_routers: + router = self.topic_routers[CONNECTION_MESSAGES.topic] + assert router.topic.model_type == ConnectionMessage + router = cast(TopicRouter[ConnectionMessage], router) + await router.publish(message) + + async def _networking_publish(self): + # This with/for pattern ensures this method doesn't return until after the receiver closes + # This is good for safety, but is mostly a redundant check. + with self.networking_receiver as networked_items: + async for topic, data in networked_items: + try: + logger.trace(f"Sending message on {topic} with payload {data}") + await self._net.gossipsub_publish(topic, data) + except NoPeersSubscribedToTopicError: + logger.trace(f"Failed to send over {topic} - No peers found.") + + +def get_node_id_keypair( + path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR, +) -> Keypair: + """ + Obtains the :class:`Keypair` associated with this node-ID. + Obtain the :class:`PeerId` by from it. 
+ """ + + def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: + return Path(str(path) + ".lock") + + # operate with cross-process lock to avoid race conditions + with FileLock(lock_path(path)): + with open(path, "a+b") as f: # opens in append-mode => starts at EOF + # if non-zero EOF, then file exists => use to get node-ID + if f.tell() != 0: + f.seek(0) # go to start & read protobuf-encoded bytes + protobuf_encoded = f.read() + + try: # if decoded successfully, save & return + return Keypair.from_protobuf_encoding(protobuf_encoded) + except ValueError as e: # on runtime error, assume corrupt file + logger.warning(f"Encountered error when trying to get keypair: {e}") + + # if no valid credentials, create new ones and persist + with open(path, "w+b") as f: + keypair = Keypair.generate_ed25519() + f.write(keypair.to_protobuf_encoding()) + return keypair diff --git a/src/exo/routing/tests/test_event_buffer.py b/src/exo/routing/tests/test_event_buffer.py new file mode 100644 index 00000000..a6f48a96 --- /dev/null +++ b/src/exo/routing/tests/test_event_buffer.py @@ -0,0 +1,141 @@ +import pytest + +from exo.shared.types.events import Event, TestEvent +from exo.utils.event_buffer import OrderedBuffer + + +def make_indexed_event(idx: int) -> tuple[int, Event]: + """Factory function to create a unique ForwarderEvent for a given index.""" + return (idx, TestEvent()) + + +@pytest.fixture +def buffer() -> OrderedBuffer[Event]: + """Provides a clean instance of OrderedBuffer[Event] for each test.""" + return OrderedBuffer[Event]() + + +@pytest.mark.asyncio +async def test_initial_state(buffer: OrderedBuffer[Event]): + """Tests that a new buffer is empty and starts at index 1.""" + assert buffer.next_idx_to_release == 0 + assert not buffer.store + assert buffer.drain() == [] + + +@pytest.mark.asyncio +async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]): + """Tests ingesting and draining a simple, ordered sequence of events.""" + events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)] + [buffer.ingest(*ev) for ev in events] + + drained_events = buffer.drain_indexed() + assert drained_events == events + assert buffer.next_idx_to_release == 3 + assert not buffer.store + + +@pytest.mark.asyncio +async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]): + """Tests that out-of-order events are buffered and drained in the correct sequence.""" + event1 = make_indexed_event(0) + event2 = make_indexed_event(1) + event3 = make_indexed_event(2) + + buffer.ingest(*event3) + buffer.ingest(*event1) + buffer.ingest(*event2) + + drained_events = buffer.drain_indexed() + assert drained_events == [event1, event2, event3] + assert buffer.next_idx_to_release == 3 + + +@pytest.mark.asyncio +async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]): + """Tests that draining stops when there is a gap in the event indices.""" + event1 = make_indexed_event(0) + event3 = make_indexed_event(2) + + buffer.ingest(*event1) + buffer.ingest(*event3) + + drained_events = buffer.drain_indexed() + assert drained_events == [event1] + assert buffer.next_idx_to_release == 1 + + assert buffer.drain() == [] + assert 2 in buffer.store + + +@pytest.mark.asyncio +async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]): + """Tests that once a gap is filled, the rest of the sequence is drained.""" + event0 = make_indexed_event(0) + event2 = make_indexed_event(2) + buffer.ingest(*event0) + buffer.ingest(*event2) + + 
buffer.drain() + assert buffer.next_idx_to_release == 1 + + event1 = make_indexed_event(1) + buffer.ingest(*event1) + + drained_events = buffer.drain_indexed() + assert [e[0] for e in drained_events] == [1, 2] + assert buffer.next_idx_to_release == 3 + + +@pytest.mark.asyncio +async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]): + """Tests that if multiple events for the same index are ingested, the first one wins.""" + event2_first = make_indexed_event(1) + event2_second = (1, TestEvent()) + + buffer.ingest(*make_indexed_event(0)) + buffer.ingest(*event2_first) + buffer.ingest(*event2_second) # This duplicate should be ignored + + drained = buffer.drain_indexed() + assert len(drained) == 2 + + assert drained[1][1].event_id == event2_first[1].event_id + assert drained[1][1].event_id != event2_second[1].event_id + + +@pytest.mark.asyncio +async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]): + """Tests that events with an index lower than next_idx_to_release are dropped.""" + buffer.ingest(*make_indexed_event(0)) + buffer.ingest(*make_indexed_event(1)) + buffer.drain() + + assert buffer.next_idx_to_release == 2 + + stale_event1 = make_indexed_event(0) + stale_event2 = make_indexed_event(1) + buffer.ingest(*stale_event1) + buffer.ingest(*stale_event2) + + assert not buffer.store + assert buffer.drain() == [] + + +@pytest.mark.asyncio +async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]): + """Tests reusing the buffer after it has been fully drained.""" + buffer.ingest(*make_indexed_event(0)) + buffer.ingest(*make_indexed_event(1)) + buffer.drain() + + assert buffer.next_idx_to_release == 2 + assert not buffer.store + + buffer.ingest(*make_indexed_event(4)) + buffer.ingest(*make_indexed_event(2)) + + drained = buffer.drain_indexed() + assert [e[0] for e in drained] == [2] + assert buffer.next_idx_to_release == 3 + assert 4 in buffer.store diff --git a/src/exo/routing/topics.py b/src/exo/routing/topics.py new file mode 100644 index 00000000..50f1c9af --- /dev/null +++ b/src/exo/routing/topics.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass +from enum import Enum + +from exo.routing.connection_message import ConnectionMessage +from exo.shared.election import ElectionMessage +from exo.shared.types.commands import ForwarderCommand +from exo.shared.types.events import ( + ForwarderEvent, +) +from exo.utils.pydantic_ext import CamelCaseModel + + +class PublishPolicy(str, Enum): + Never = "Never" + """Never publish to the network - this is a local message""" + Minimal = "Minimal" + """Only publish when there is no local receiver for this type of message""" + Always = "Always" + """Always publish to the network""" + + +@dataclass # (frozen=True) +class TypedTopic[T: CamelCaseModel]: + topic: str + publish_policy: PublishPolicy + + model_type: type[ + T + ] # This can be worked around with evil type hacking, see https://stackoverflow.com/a/71720366 - I don't think it's necessary here. 
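[Reviewer note] The new test_event_buffer.py suite above pins down the OrderedBuffer contract precisely: items may arrive out of order, are released strictly in index order, and duplicate or stale indices are dropped. A minimal sketch that satisfies those tests, assuming implementation details the diff does not show (the real exo.utils.event_buffer may differ):

class OrderedBuffer[T]:
    def __init__(self) -> None:
        self.next_idx_to_release: int = 0
        self.store: dict[int, T] = {}

    def ingest(self, idx: int, item: T) -> None:
        # Stale (already released) and duplicate indices are ignored;
        # the first item seen for a given index wins.
        if idx >= self.next_idx_to_release and idx not in self.store:
            self.store[idx] = item

    def drain_indexed(self) -> list[tuple[int, T]]:
        # Release the contiguous run starting at next_idx_to_release.
        # Anything after a gap stays buffered until the gap is filled.
        out: list[tuple[int, T]] = []
        while self.next_idx_to_release in self.store:
            idx = self.next_idx_to_release
            out.append((idx, self.store.pop(idx)))
            self.next_idx_to_release += 1
        return out

    def drain(self) -> list[T]:
        return [item for _, item in self.drain_indexed()]
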
+ + @staticmethod + def serialize(t: T) -> bytes: + return t.model_dump_json().encode("utf-8") + + def deserialize(self, b: bytes) -> T: + return self.model_type.model_validate_json(b.decode("utf-8")) + + +GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent) +LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent) +COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand) +ELECTION_MESSAGES = TypedTopic( + "election_messages", PublishPolicy.Always, ElectionMessage +) +CONNECTION_MESSAGES = TypedTopic( + "connection_messages", PublishPolicy.Never, ConnectionMessage +) diff --git a/src/exo/shared/apply/apply.py b/src/exo/shared/apply.py similarity index 66% rename from src/exo/shared/apply/apply.py rename to src/exo/shared/apply.py index 75c102f4..3c0f2d5d 100644 --- a/src/exo/shared/apply/apply.py +++ b/src/exo/shared/apply.py @@ -1,18 +1,17 @@ -from __future__ import annotations - import copy -from functools import singledispatch from typing import Mapping +from loguru import logger + from exo.shared.types.common import NodeId from exo.shared.types.events import ( + ChunkGenerated, Event, - EventFromEventLog, + IndexedEvent, InstanceActivated, InstanceCreated, InstanceDeactivated, InstanceDeleted, - InstanceReplacedAtomically, NodePerformanceMeasured, RunnerDeleted, RunnerStatusUpdated, @@ -20,48 +19,74 @@ from exo.shared.types.events import ( TaskDeleted, TaskFailed, TaskStateUpdated, + TestEvent, TopologyEdgeCreated, TopologyEdgeDeleted, - TopologyEdgeReplacedAtomically, TopologyNodeCreated, WorkerStatusUpdated, ) from exo.shared.types.profiling import NodePerformanceProfile from exo.shared.types.state import State from exo.shared.types.tasks import Task, TaskId, TaskStatus -from exo.shared.types.topology import Connection, Node -from exo.shared.types.worker.common import NodeStatus, RunnerId +from exo.shared.types.topology import NodeInfo +from exo.shared.types.worker.common import RunnerId, WorkerStatus from exo.shared.types.worker.instances import Instance, InstanceId, InstanceStatus from exo.shared.types.worker.runners import RunnerStatus -@singledispatch def event_apply(event: Event, state: State) -> State: - """Apply an event to *state*. - - Events decorated with ``@no_op_event`` set ``__no_apply__ = True`` on the - class. Such events are considered *no-ops* and therefore leave the state - unchanged without requiring a dedicated handler in this dispatch table. 
- """ - - if getattr(event, "__no_apply__", False): - return state - - raise RuntimeError(f"no handler registered for event type {type(event).__name__}") + """Apply an event to state.""" + match event: + case TestEvent() | ChunkGenerated(): + return state + case InstanceActivated(): + return apply_instance_activated(event, state) + case InstanceCreated(): + return apply_instance_created(event, state) + case InstanceDeactivated(): + return apply_instance_deactivated(event, state) + case InstanceDeleted(): + return apply_instance_deleted(event, state) + case NodePerformanceMeasured(): + return apply_node_performance_measured(event, state) + case RunnerDeleted(): + return apply_runner_deleted(event, state) + case RunnerStatusUpdated(): + return apply_runner_status_updated(event, state) + case TaskCreated(): + return apply_task_created(event, state) + case TaskDeleted(): + return apply_task_deleted(event, state) + case TaskFailed(): + return apply_task_failed(event, state) + case TaskStateUpdated(): + return apply_task_state_updated(event, state) + case WorkerStatusUpdated(): + return apply_worker_status_updated(event, state) + case TopologyNodeCreated(): + return apply_topology_node_created(event, state) + case TopologyEdgeCreated(): + return apply_topology_edge_created(event, state) + case TopologyEdgeDeleted(): + return apply_topology_edge_deleted(event, state) -def apply(state: State, event: EventFromEventLog[Event]) -> State: +def apply(state: State, event: IndexedEvent) -> State: + # Just to test that events are only applied in correct order + if state.last_event_applied_idx != event.idx - 1: + logger.warning( + f"Expected event {state.last_event_applied_idx + 1} but received {event.idx}" + ) + assert state.last_event_applied_idx == event.idx - 1 new_state: State = event_apply(event.event, state) - return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log}) + return new_state.model_copy(update={"last_event_applied_idx": event.idx}) -@event_apply.register(TaskCreated) def apply_task_created(event: TaskCreated, state: State) -> State: new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task} return state.model_copy(update={"tasks": new_tasks}) -@event_apply.register(TaskDeleted) def apply_task_deleted(event: TaskDeleted, state: State) -> State: new_tasks: Mapping[TaskId, Task] = { tid: task for tid, task in state.tasks.items() if tid != event.task_id @@ -69,7 +94,6 @@ def apply_task_deleted(event: TaskDeleted, state: State) -> State: return state.model_copy(update={"tasks": new_tasks}) -@event_apply.register(TaskStateUpdated) def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State: if event.task_id not in state.tasks: return state @@ -86,7 +110,6 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State: return state.model_copy(update={"tasks": new_tasks}) -@event_apply.register(TaskFailed) def apply_task_failed(event: TaskFailed, state: State) -> State: if event.task_id not in state.tasks: return state @@ -98,7 +121,6 @@ def apply_task_failed(event: TaskFailed, state: State) -> State: return state.model_copy(update={"tasks": new_tasks}) -@event_apply.register(InstanceCreated) def apply_instance_created(event: InstanceCreated, state: State) -> State: instance = event.instance new_instances: Mapping[InstanceId, Instance] = { @@ -108,13 +130,12 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State: return state.model_copy(update={"instances": new_instances}) 
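[Reviewer note] The new apply() enforces a strict-ordering invariant: an IndexedEvent is only accepted when its idx is exactly last_event_applied_idx + 1. A hedged sketch of how that plays out when a follower replays the log; get_events_since and apply are the functions from this diff, everything else is illustrative:

from exo.shared.apply import apply


async def replay_from_log(storage, state):
    # storage is an AsyncSQLiteEventStorage, state a State, per this diff.
    # apply() asserts indexed.idx == state.last_event_applied_idx + 1, so a
    # gap or reordering in the log fails loudly instead of silently
    # producing a diverged materialized state.
    for indexed in await storage.get_events_since(state.last_event_applied_idx):
        state = apply(state, indexed)
    return state
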
-@event_apply.register(InstanceActivated) def apply_instance_activated(event: InstanceActivated, state: State) -> State: if event.instance_id not in state.instances: return state updated_instance = state.instances[event.instance_id].model_copy( - update={"type": InstanceStatus.ACTIVE} + update={"instance_type": InstanceStatus.ACTIVE} ) new_instances: Mapping[InstanceId, Instance] = { **state.instances, @@ -123,13 +144,12 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State: return state.model_copy(update={"instances": new_instances}) -@event_apply.register(InstanceDeactivated) def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> State: if event.instance_id not in state.instances: return state updated_instance = state.instances[event.instance_id].model_copy( - update={"type": InstanceStatus.INACTIVE} + update={"instance_type": InstanceStatus.INACTIVE} ) new_instances: Mapping[InstanceId, Instance] = { **state.instances, @@ -138,7 +158,6 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat return state.model_copy(update={"instances": new_instances}) -@event_apply.register(InstanceDeleted) def apply_instance_deleted(event: InstanceDeleted, state: State) -> State: new_instances: Mapping[InstanceId, Instance] = { iid: inst for iid, inst in state.instances.items() if iid != event.instance_id @@ -146,19 +165,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State: return state.model_copy(update={"instances": new_instances}) -@event_apply.register(InstanceReplacedAtomically) -def apply_instance_replaced_atomically( - event: InstanceReplacedAtomically, state: State -) -> State: - new_instances = dict(state.instances) - if event.instance_to_replace in new_instances: - del new_instances[event.instance_to_replace] - if event.new_instance_id in state.instances: - new_instances[event.new_instance_id] = state.instances[event.new_instance_id] - return state.model_copy(update={"instances": new_instances}) - - -@event_apply.register(RunnerStatusUpdated) def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State: new_runners: Mapping[RunnerId, RunnerStatus] = { **state.runners, @@ -167,7 +173,6 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta return state.model_copy(update={"runners": new_runners}) -@event_apply.register(RunnerDeleted) def apply_runner_deleted(event: RunnerDeleted, state: State) -> State: new_runners: Mapping[RunnerId, RunnerStatus] = { rid: rs for rid, rs in state.runners.items() if rid != event.runner_id @@ -175,7 +180,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State: return state.model_copy(update={"runners": new_runners}) -@event_apply.register(NodePerformanceMeasured) +# TODO: This whole function needs fixing def apply_node_performance_measured( event: NodePerformanceMeasured, state: State ) -> State: @@ -187,58 +192,39 @@ def apply_node_performance_measured( topology = copy.copy(state.topology) if not topology.contains_node(event.node_id): # TODO: figure out why this is happening in the first place - topology.add_node(Node(node_id=event.node_id)) + topology.add_node(NodeInfo(node_id=event.node_id)) topology.update_node_profile(event.node_id, event.node_profile) return state.model_copy(update={"topology": topology}) -@event_apply.register(WorkerStatusUpdated) def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> State: - new_node_status: Mapping[NodeId, NodeStatus] = { + 
new_node_status: Mapping[NodeId, WorkerStatus] = { **state.node_status, event.node_id: event.node_state, } return state.model_copy(update={"node_status": new_node_status}) -@event_apply.register(TopologyNodeCreated) def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State: topology = copy.copy(state.topology) - topology.add_node(Node(node_id=event.node_id)) - if event.role == "MASTER": - topology.set_master_node_id(event.node_id) + topology.add_node(NodeInfo(node_id=event.node_id)) return state.model_copy(update={"topology": topology}) -@event_apply.register(TopologyEdgeCreated) def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State: topology = copy.copy(state.topology) topology.add_connection(event.edge) return state.model_copy(update={"topology": topology}) -@event_apply.register(TopologyEdgeReplacedAtomically) -def apply_topology_edge_replaced_atomically( - event: TopologyEdgeReplacedAtomically, state: State -) -> State: - topology = copy.copy(state.topology) - topology.update_connection_profile(event.edge) - return state.model_copy(update={"topology": topology}) - - -@event_apply.register(TopologyEdgeDeleted) def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State: topology = copy.copy(state.topology) if not topology.contains_connection(event.edge): return state topology.remove_connection(event.edge) - opposite_edge = Connection( - local_node_id=event.edge.send_back_node_id, - send_back_node_id=event.edge.local_node_id, - local_multiaddr=event.edge.send_back_multiaddr, - send_back_multiaddr=event.edge.local_multiaddr, - ) - if not topology.contains_connection(opposite_edge): - return state.model_copy(update={"topology": topology}) - topology.remove_connection(opposite_edge) + if not topology.contains_connection(event.edge) and topology.contains_connection( + event.edge.reverse() + ): + topology.remove_connection(event.edge.reverse()) + # TODO: Clean up removing the reverse connection return state.model_copy(update={"topology": topology}) diff --git a/src/exo/shared/apply/__init__.py b/src/exo/shared/apply/__init__.py deleted file mode 100644 index dc22de1e..00000000 --- a/src/exo/shared/apply/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .apply import apply - -__all__ = ["apply"] diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index 2be7d1f2..2961c686 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -21,8 +21,10 @@ EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring" EXO_IPC_DIR = EXO_HOME / "ipc" # libp2p topics for event forwarding -LIBP2P_WORKER_EVENTS_TOPIC = "worker_events" +LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events" LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events" +LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message" +LIBP2P_COMMANDS_TOPIC = "commands" # lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip. 
 LB_TFLOPS = 2.3
diff --git a/src/exo/shared/db/__init__.py b/src/exo/shared/db/__init__.py
index 955a46e2..e69de29b 100644
--- a/src/exo/shared/db/__init__.py
+++ b/src/exo/shared/db/__init__.py
@@ -1,5 +0,0 @@
-"""Database implementations for event storage."""
-
-from .sqlite import AsyncSQLiteEventStorage, EventStorageProtocol
-
-__all__ = ["AsyncSQLiteEventStorage", "EventStorageProtocol"]
diff --git a/src/exo/shared/db/config.py b/src/exo/shared/db/config.py
new file mode 100644
index 00000000..c5d0e01b
--- /dev/null
+++ b/src/exo/shared/db/config.py
@@ -0,0 +1,19 @@
+from pathlib import Path
+
+from pydantic import BaseModel
+
+from exo.shared.constants import EXO_GLOBAL_EVENT_DB
+
+
+class EventLogConfig(BaseModel):
+    """Configuration for the event log system"""
+
+    # Batch processing settings
+    batch_size: int = 100
+    batch_timeout_ms: int = 100
+    debounce_ms: int = 10
+    max_age_ms: int = 100
+
+    def get_db_path(self) -> Path:
+        """Get the path of the global event log database"""
+        return EXO_GLOBAL_EVENT_DB
diff --git a/src/exo/shared/db/sqlite/connector.py b/src/exo/shared/db/connector.py
similarity index 92%
rename from src/exo/shared/db/sqlite/connector.py
rename to src/exo/shared/db/connector.py
index 5cb514b8..141cac38 100644
--- a/src/exo/shared/db/sqlite/connector.py
+++ b/src/exo/shared/db/connector.py
@@ -8,13 +8,13 @@
 from pathlib import Path
 from typing import Any, cast

 from loguru import logger
+from pydantic import TypeAdapter
 from sqlalchemy import text
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, create_async_engine

-from exo.shared.types.events import Event, EventParser, NodeId
-from exo.shared.types.events._events import Heartbeat
-from exo.shared.types.events.components import EventFromEventLog
+from exo.shared.types.common import NodeId
+from exo.shared.types.events import Event, IndexedEvent, event_tag

 from .types import StoredEvent

@@ -73,9 +73,7 @@ class AsyncSQLiteEventStorage:
         for event in events:
             await self._write_queue.put((event, origin))

-    async def get_events_since(
-        self, last_idx: int, ignore_no_op_events: bool = False
-    ) -> Sequence[EventFromEventLog[Event]]:
+    async def get_events_since(self, last_idx: int) -> Sequence[IndexedEvent]:
         """Retrieve events after a specific index."""
         if self._closed:
             raise RuntimeError("Storage is closed")
@@ -92,10 +90,10 @@ class AsyncSQLiteEventStorage:
             )
             rows = result.fetchall()

-            events: list[EventFromEventLog[Event]] = []
+            events: list[IndexedEvent] = []
             for row in rows:
                 rowid: int = cast(int, row[0])
-                origin: str = cast(str, row[1])
+                # origin: str = cast(str, row[1])
                 # Parse JSON string to dict
                 raw_event_data = row[2]  # type: ignore[reportAny] - SQLAlchemy result is Any
                 if isinstance(raw_event_data, str):
@@ -104,14 +102,12 @@ class AsyncSQLiteEventStorage:
                     )
                 else:
                     event_data = cast(dict[str, Any], raw_event_data)
-                event = EventParser.validate_python(event_data)
-                if ignore_no_op_events and event.__no_apply__:
-                    continue
+                event: Event = TypeAdapter(Event).validate_python(event_data)  # type: ignore
                 events.append(
-                    EventFromEventLog(
-                        event=event,
-                        origin=NodeId(origin),
-                        idx_in_log=rowid,  # rowid becomes idx_in_log
+                    IndexedEvent(
+                        event=event,  # type: ignore
+                        # origin=NodeId(origin),
+                        idx=rowid,  # rowid becomes idx
                     )
                 )

@@ -325,7 +321,7 @@ class AsyncSQLiteEventStorage:
                 for event, origin in batch:
                     stored_event = StoredEvent(
                         origin=origin,
-                        event_type=event.event_type,
+                        event_type=event_tag(event),
                         event_id=str(event.event_id),
                         event_data=event.model_dump(
                             mode="json"
@@ -334,8 +330,7 @@
                 session.add(stored_event)
             await session.commit()

-            if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
-                logger.debug(f"Committed batch of {len(batch)} events")
+            logger.debug(f"Committed batch of {len(batch)} events")

         except OperationalError as e:
             if "database is locked" in str(e):
@@ -393,7 +388,7 @@
                 for event, origin in batch:
                     stored_event = StoredEvent(
                         origin=origin,
-                        event_type=event.event_type,
+                        event_type=event_tag(event),
                         event_id=str(event.event_id),
                         event_data=event.model_dump(mode="json"),
                     )
@@ -401,10 +396,9 @@

                 await session.commit()

-                if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
-                    logger.debug(
-                        f"Committed batch of {len(batch)} events after {retry_count} retries"
-                    )
+                logger.debug(
+                    f"Committed batch of {len(batch)} events after {retry_count} retries"
+                )

                 return

             except OperationalError as e:
diff --git a/src/exo/shared/db/event_log_manager.py b/src/exo/shared/db/event_log_manager.py
new file mode 100644
index 00000000..b2fd3b18
--- /dev/null
+++ b/src/exo/shared/db/event_log_manager.py
@@ -0,0 +1,110 @@
+import asyncio
+from typing import cast
+
+from loguru import logger
+from sqlalchemy.exc import OperationalError
+
+from exo.shared.constants import EXO_HOME
+from exo.shared.db.config import EventLogConfig
+from exo.shared.db.connector import AsyncSQLiteEventStorage
+from exo.utils.fs import ensure_directory_exists
+
+
+class EventLogManager:
+    """
+    Manages the connector for the global event log.
+    Used by both master and worker processes.
+    """
+
+    def __init__(self, config: EventLogConfig):
+        self._config = config
+        self._connector: AsyncSQLiteEventStorage | None = None
+
+        # Ensure base directory exists
+        ensure_directory_exists(EXO_HOME)
+
+    # TODO: This pattern only exists to avoid an async __init__; a cleaner approach is an async create() factory function, like in runner_supervisor.
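The TODO above mentions an async create() factory; a hedged sketch of what that could look like here (the class name is hypothetical and the shape is assumed, not taken from runner_supervisor):

    from typing import Self

    from exo.shared.db.config import EventLogConfig
    from exo.shared.db.connector import AsyncSQLiteEventStorage

    class EventLogManagerViaFactory:
        def __init__(self, config: EventLogConfig, connector: AsyncSQLiteEventStorage):
            # Plain synchronous __init__; the connector is already started.
            self._config = config
            self._connector = connector

        @classmethod
        async def create(cls, config: EventLogConfig) -> Self:
            connector = AsyncSQLiteEventStorage(
                db_path=config.get_db_path(),
                batch_size=config.batch_size,
                batch_timeout_ms=config.batch_timeout_ms,
                debounce_ms=config.debounce_ms,
                max_age_ms=config.max_age_ms,
            )
            await connector.start()  # all awaiting happens here, never in __init__
            return cls(config, connector)

Callers would then write manager = await EventLogManagerViaFactory.create(config), so no instance ever exists in a half-initialized state and the separate initialize() step disappears.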
+    async def initialize(self, max_retries: int = 3) -> None:
+        """Initialize the connector with retry logic - call this during startup"""
+        # Both master and worker need this connector
+        retry_count: int = 0
+        last_error: Exception | None = None
+
+        while retry_count < max_retries:
+            try:
+                await self.get_connector()
+                break
+            except OperationalError as e:
+                last_error = e
+                if "database is locked" in str(e) and retry_count < max_retries - 1:
+                    retry_count += 1
+                    delay = cast(float, 0.5 * (2**retry_count))
+                    logger.warning(
+                        f"Database locked while initializing db, retry {retry_count}/{max_retries} after {delay}s"
+                    )
+                    await asyncio.sleep(delay)
+                else:
+                    logger.opt(exception=e).error(
+                        f"Failed to initialize db after {retry_count + 1} attempts"
+                    )
+                    raise RuntimeError(
+                        f"Could not initialize db after {retry_count + 1} attempts"
+                    ) from e
+            except Exception as e:
+                logger.opt(exception=e).error("Unexpected error initializing db")
+                raise
+
+        if retry_count >= max_retries and last_error:
+            raise RuntimeError(
+                f"Could not initialize db after {max_retries} attempts"
+            ) from last_error
+        logger.bind(user_facing=True).info("Initialized event log connector")
+
+    async def get_connector(self) -> AsyncSQLiteEventStorage:
+        """Get or create the event log connector"""
+        if not self._connector:
+            db_path = self._config.get_db_path()
+
+            try:
+                connector = AsyncSQLiteEventStorage(
+                    db_path=db_path,
+                    batch_size=self._config.batch_size,
+                    batch_timeout_ms=self._config.batch_timeout_ms,
+                    debounce_ms=self._config.debounce_ms,
+                    max_age_ms=self._config.max_age_ms,
+                )
+
+                # Start the connector (creates tables if needed)
+                await connector.start()
+
+                self._connector = connector
+                logger.bind(user_facing=True).info(
+                    f"Initialized db connector at {db_path}"
+                )
+            except Exception as e:
+                logger.bind(user_facing=True).opt(exception=e).error(
+                    "Failed to create db connector"
+                )
+                raise
+
+        return self._connector
+
+    @property
+    def events(self) -> AsyncSQLiteEventStorage:
+        """Access the event log (must call initialize() first)"""
+        if not self._connector:
+            raise RuntimeError(
+                "Event log manager not initialized. Call initialize() first."
+            )
+        return self._connector
+
+    async def close(self) -> None:
+        """Close the event log connector"""
+        assert self._connector is not None
+        await self._connector.close()
+        logger.bind(user_facing=True).info("Closed db connector")
+        self._connector = None
diff --git a/src/exo/shared/db/sqlite/__init__.py b/src/exo/shared/db/sqlite/__init__.py
deleted file mode 100644
index d6c08ef5..00000000
--- a/src/exo/shared/db/sqlite/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""SQLite event storage implementation."""
-
-from .config import EventLogConfig, EventLogType
-from .connector import AsyncSQLiteEventStorage
-from .event_log_manager import EventLogManager
-from .types import EventStorageProtocol, StoredEvent
-
-__all__ = [
-    "AsyncSQLiteEventStorage",
-    "EventLogConfig",
-    "EventLogManager",
-    "EventLogType",
-    "EventStorageProtocol",
-    "StoredEvent",
-]
diff --git a/src/exo/shared/db/sqlite/config.py b/src/exo/shared/db/sqlite/config.py
deleted file mode 100644
index f6f6ac97..00000000
--- a/src/exo/shared/db/sqlite/config.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from enum import Enum
-from pathlib import Path
-
-from pydantic import BaseModel
-
-from exo.shared.constants import EXO_GLOBAL_EVENT_DB, EXO_WORKER_EVENT_DB
-
-
-class EventLogType(str, Enum):
-    """Types of event logs in the system"""
-
-    WORKER_EVENTS = "worker_events"
-    GLOBAL_EVENTS = "global_events"
-
-
-class EventLogConfig(BaseModel):
-    """Configuration for the event log system"""
-
-    # Batch processing settings
-    batch_size: int = 100
-    batch_timeout_ms: int = 100
-    debounce_ms: int = 10
-    max_age_ms: int = 100
-
-    def get_db_path(self, log_type: EventLogType) -> Path:
-        """Get the full path for a specific event log type"""
-        if log_type == EventLogType.WORKER_EVENTS:
-            return EXO_WORKER_EVENT_DB
-        elif log_type == EventLogType.GLOBAL_EVENTS:
-            return EXO_GLOBAL_EVENT_DB
-        else:
-            raise ValueError(f"Unknown log type: {log_type}")
diff --git a/src/exo/shared/db/sqlite/event_log_manager.py b/src/exo/shared/db/sqlite/event_log_manager.py
deleted file mode 100644
index 00144ffc..00000000
--- a/src/exo/shared/db/sqlite/event_log_manager.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import asyncio
-from typing import Dict, Optional, cast
-
-from loguru import logger
-from sqlalchemy.exc import OperationalError
-
-from exo.shared.constants import EXO_HOME
-from exo.shared.db.sqlite.config import EventLogConfig, EventLogType
-from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
-from exo.shared.utils.fs import ensure_directory_exists
-
-
-class EventLogManager:
-    """
-    Manages both worker and global event log connectors.
-    Used by both master and worker processes with different access patterns:
-
-    - Worker: writes to worker_events, tails global_events
-    - Master (elected): writes to global_events, tails global_events
-    - Master (replica): writes to worker_events, tails global_events
-    """
-
-    def __init__(self, config: EventLogConfig):
-        self._config = config
-        self._connectors: Dict[EventLogType, AsyncSQLiteEventStorage] = {}
-
-        # Ensure base directory exists
-        ensure_directory_exists(EXO_HOME)
-
-    # TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
- async def initialize(self, max_retries: int = 3) -> None: - """Initialize both connectors with retry logic - call this during startup""" - # Both master and worker need both connectors - for log_type in [EventLogType.WORKER_EVENTS, EventLogType.GLOBAL_EVENTS]: - retry_count: int = 0 - last_error: Optional[Exception] = None - - while retry_count < max_retries: - try: - await self.get_connector(log_type) - break - except OperationalError as e: - last_error = e - if "database is locked" in str(e) and retry_count < max_retries - 1: - retry_count += 1 - delay = cast(float, 0.5 * (2**retry_count)) - logger.warning( - f"Database locked while initializing {log_type.value}, retry {retry_count}/{max_retries} after {delay}s" - ) - await asyncio.sleep(delay) - else: - logger.opt(exception=e).error( - f"Failed to initialize {log_type.value} after {retry_count + 1} attempts" - ) - raise RuntimeError( - f"Could not initialize {log_type.value} database after {retry_count + 1} attempts" - ) from e - except Exception as e: - logger.opt(exception=e).error( - f"Unexpected error initializing {log_type.value}" - ) - raise - - if retry_count >= max_retries and last_error: - raise RuntimeError( - f"Could not initialize {log_type.value} database after {max_retries} attempts" - ) from last_error - logger.bind(user_facing=True).info("Initialized all event log connectors") - - async def get_connector(self, log_type: EventLogType) -> AsyncSQLiteEventStorage: - """Get or create a connector for the specified log type""" - if log_type not in self._connectors: - db_path = self._config.get_db_path(log_type) - - try: - connector = AsyncSQLiteEventStorage( - db_path=db_path, - batch_size=self._config.batch_size, - batch_timeout_ms=self._config.batch_timeout_ms, - debounce_ms=self._config.debounce_ms, - max_age_ms=self._config.max_age_ms, - ) - - # Start the connector (creates tables if needed) - await connector.start() - - self._connectors[log_type] = connector - logger.bind(user_facing=True).info( - f"Initialized {log_type.value} connector at {db_path}" - ) - except Exception as e: - logger.bind(user_facing=True).opt(exception=e).error( - f"Failed to create {log_type.value} connector" - ) - raise - - return self._connectors[log_type] - - @property - def worker_events(self) -> AsyncSQLiteEventStorage: - """Access worker events log (must call initialize() first)""" - if EventLogType.WORKER_EVENTS not in self._connectors: - raise RuntimeError( - "Event log manager not initialized. Call initialize() first." - ) - return self._connectors[EventLogType.WORKER_EVENTS] - - @property - def global_events(self) -> AsyncSQLiteEventStorage: - """Access global events log (must call initialize() first)""" - if EventLogType.GLOBAL_EVENTS not in self._connectors: - raise RuntimeError( - "Event log manager not initialized. Call initialize() first." 
- ) - return self._connectors[EventLogType.GLOBAL_EVENTS] - - async def close_all(self) -> None: - """Close all open connectors""" - for log_type, connector in self._connectors.items(): - await connector.close() - logger.bind(user_facing=True).info(f"Closed {log_type.value} connector") - self._connectors.clear() diff --git a/src/exo/shared/db/sqlite/types.py b/src/exo/shared/db/types.py similarity index 51% rename from src/exo/shared/db/sqlite/types.py rename to src/exo/shared/db/types.py index 5fc0f582..0795e3d0 100644 --- a/src/exo/shared/db/sqlite/types.py +++ b/src/exo/shared/db/types.py @@ -1,13 +1,9 @@ from datetime import datetime, timezone -from typing import Any, Protocol, Sequence +from typing import Any from sqlalchemy import DateTime, Index from sqlmodel import JSON, Column, Field, SQLModel -from exo.shared.types.common import NodeId -from exo.shared.types.events import Event -from exo.shared.types.events.components import EventFromEventLog - class StoredEvent(SQLModel, table=True): """SQLite representation of an event in the event log. @@ -29,28 +25,3 @@ class StoredEvent(SQLModel, table=True): ) __table_args__ = (Index("idx_events_origin_created", "origin", "created_at"),) - - -class EventStorageProtocol(Protocol): - """Protocol for event storage implementations.""" - - async def append_events(self, events: Sequence[Event], origin: NodeId) -> None: - """Append events to the log (fire-and-forget). - - Events are queued for batched writing and assigned idx_in_log - when committed to storage. - """ - ... - - async def get_events_since( - self, last_idx: int - ) -> Sequence[EventFromEventLog[Event]]: - """Retrieve events after a specific index. - - Returns events in idx_in_log order. - """ - ... - - async def close(self) -> None: - """Close the storage connection and cleanup resources.""" - ... diff --git a/src/exo/shared/election.py b/src/exo/shared/election.py new file mode 100644 index 00000000..a5f94c66 --- /dev/null +++ b/src/exo/shared/election.py @@ -0,0 +1,183 @@ +from typing import Self + +import anyio +from anyio import ( + CancelScope, + Event, + create_task_group, + get_cancelled_exc_class, +) +from anyio.abc import TaskGroup +from loguru import logger + +from exo.routing.connection_message import ConnectionMessage +from exo.shared.types.common import NodeId +from exo.utils.channels import Receiver, Sender +from exo.utils.pydantic_ext import CamelCaseModel + +ELECTION_TIMEOUT = 3.0 + + +class ElectionMessage(CamelCaseModel): + clock: int + seniority: int + node_id: NodeId + + # Could eventually include a list of neighbour nodes for centrality + def __lt__(self, other: Self): + if self.seniority != other.seniority: + return self.seniority < other.seniority + else: + return self.node_id < other.node_id + + +class ElectionResult(CamelCaseModel): + node_id: NodeId + is_new_master: bool + historic_messages: list[ConnectionMessage] + + +class Election: + def __init__( + self, + node_id: NodeId, + election_message_receiver: Receiver[ElectionMessage], + election_message_sender: Sender[ElectionMessage], + election_result_sender: Sender[ElectionResult], + connection_message_receiver: Receiver[ConnectionMessage], + *, + is_candidate: bool = True, + seniority: int = 0, + ): + # If we aren't a candidate, simply don't increment seniority. + # For reference: This node can be elected master if all nodes are not master candidates + # Any master candidate will automatically win out over this node. 
+        self.seniority = seniority if is_candidate else -1
+        self.clock = 0
+        self.node_id = node_id
+        # Every node spawns as master
+        self.master_node_id: NodeId = node_id
+
+        self._em_sender = election_message_sender
+        self._em_receiver = election_message_receiver
+        self._er_sender = election_result_sender
+        self._cm_receiver = connection_message_receiver
+
+        # Campaign state
+        self._candidates: list[ElectionMessage] = []
+        self._campaign_cancel_scope: CancelScope | None = None
+        self._campaign_done: Event | None = None
+        self._tg: TaskGroup | None = None
+        self._connection_messages: list[ConnectionMessage] = []
+
+    async def run(self):
+        logger.info("Starting Election")
+        async with create_task_group() as tg:
+            self._tg = tg
+            tg.start_soon(self._election_receiver)
+            tg.start_soon(self._connection_receiver)
+            await self._campaign(None)
+
+            if self._campaign_cancel_scope is not None:
+                self._campaign_cancel_scope.cancel()
+            # Only exit once the latest campaign has finished
+            if self._campaign_done is not None:
+                await self._campaign_done.wait()
+
+    async def elect(self, node_id: NodeId) -> None:
+        is_new_master = node_id != self.master_node_id
+        self.master_node_id = node_id
+        await self._er_sender.send(
+            ElectionResult(
+                node_id=node_id,
+                is_new_master=is_new_master,
+                historic_messages=self._connection_messages,
+            )
+        )
+
+    async def shutdown(self) -> None:
+        if not self._tg:
+            logger.warning(
+                "Attempted to shut down an election service that was not running"
+            )
+            return
+        self._tg.cancel_scope.cancel()
+
+    async def _election_receiver(self) -> None:
+        with self._em_receiver as election_messages:
+            async for message in election_messages:
+                if message.node_id == self.node_id:
+                    # Drop messages from us (see exo.routing.router)
+                    continue
+                # If a new round is starting, we participate
+                if message.clock > self.clock:
+                    self.clock = message.clock
+                    await self._campaign(message)
+                    continue
+                # Dismiss old messages
+                if message.clock < self.clock:
+                    continue
+                logger.debug(f"Election added candidate {message}")
+                # Now we are processing this round's messages, including the message that triggered this round.
+                self._candidates.append(message)
+
+    async def _connection_receiver(self) -> None:
+        with self._cm_receiver as connection_messages:
+            async for msg in connection_messages:
+                # These messages are strictly peer to peer
+                self.clock += 1
+                await self._campaign(None)
+                self._connection_messages.append(msg)
+
+    async def _campaign(self, initial_message: ElectionMessage | None) -> None:
+        # Kill the old campaign
+        if self._campaign_cancel_scope:
+            self._campaign_cancel_scope.cancel()
+        if self._campaign_done:
+            await self._campaign_done.wait()
+
+        candidates: list[ElectionMessage] = []
+        if initial_message:
+            candidates.append(initial_message)
+        self._candidates = candidates
+        done = Event()
+        self._campaign_done = done
+
+        assert self._tg is not None, (
+            "Election campaign started before election service initialized"
+        )
+        # Spin off a new campaign
+        self._tg.start_soon(self._complete_campaign, self.clock, candidates, done)
+
+    async def _complete_campaign(
+        self, clock: int, candidates: list[ElectionMessage], done: Event
+    ) -> None:
+        scope = CancelScope()
+        try:
+            with scope:
+                self._campaign_cancel_scope = scope
+                logger.info(f"Election {clock} started")
+
+                candidates.append(self._election_status(clock))
+                await self._em_sender.send(self._election_status(clock))
+
+                await anyio.sleep(ELECTION_TIMEOUT)
+
+                # Election finished!
+ candidates = sorted(candidates) + logger.debug(f"Election queue {candidates}") + elected = candidates[-1] + logger.info("Election finished") + if self.node_id == elected.node_id and self.seniority >= 0: + self.seniority = max(self.seniority, len(candidates)) + await self.elect(elected.node_id) + except get_cancelled_exc_class(): + logger.info("Election cancelled") + finally: + if self._campaign_cancel_scope is scope: + self._campaign_cancel_scope = None + done.set() + + def _election_status(self, clock: int | None = None) -> ElectionMessage: + c = self.clock if clock is None else clock + return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id) diff --git a/src/exo/shared/env.py b/src/exo/shared/env.py deleted file mode 100644 index c87cf094..00000000 --- a/src/exo/shared/env.py +++ /dev/null @@ -1,28 +0,0 @@ -import logging -import os -from typing import TypeVar - -from pydantic import BaseModel, ConfigDict, ValidationError - -env_model_config = ConfigDict( - strict=True, - frozen=True, - extra="forbid", -) - - -class BaseEnv(BaseModel): - model_config = env_model_config - - -EnvSchema = TypeVar("EnvSchema", bound=BaseEnv) - - -def get_validated_env( - environment_schema: type[EnvSchema], logger: logging.Logger -) -> EnvSchema: - try: - return environment_schema.model_validate(os.environ, strict=True) - except ValidationError as e: - logger.error("Environment Variables Validation Failed: %s", e) - raise e diff --git a/src/exo/shared/global_conn.py b/src/exo/shared/global_conn.py index 5def2999..7ecf5928 100644 --- a/src/exo/shared/global_conn.py +++ b/src/exo/shared/global_conn.py @@ -18,6 +18,7 @@ class AsyncConnection[SendT, RecvT]: - await send(...) from asyncio code - send_sync(...) from executor/background threads """ + def __init__(self, conn: Connection): self._conn = conn self._send_lock = threading.Lock() @@ -44,7 +45,7 @@ class AsyncConnection[SendT, RecvT]: def _recv_blocking(self) -> RecvT: # Not strictly needed in your parent, but safe if misused elsewhere with self._recv_lock: - return self._conn.recv() # type: ignore[no-any-return] + return self._conn.recv() # type: ignore[no-any-return] async def poll(self, timeout: float | None = None) -> bool: return await asyncio.to_thread(self._conn.poll, timeout) @@ -52,12 +53,15 @@ class AsyncConnection[SendT, RecvT]: def close(self) -> None: self._conn.close() + _conn: Optional[AsyncConnection[RunnerResponse, RunnerMessage]] = None + def set_conn(c: AsyncConnection[RunnerResponse, RunnerMessage]) -> None: global _conn _conn = c + def get_conn() -> AsyncConnection[RunnerResponse, RunnerMessage]: if _conn is None: raise RuntimeError("Global conn has not been set yet") diff --git a/src/exo/shared/ipc/file_mutex/flock_mutex.py b/src/exo/shared/ipc/file_mutex/flock_mutex.py index fda65d60..da486dbf 100644 --- a/src/exo/shared/ipc/file_mutex/flock_mutex.py +++ b/src/exo/shared/ipc/file_mutex/flock_mutex.py @@ -12,7 +12,7 @@ import time from enum import Enum from typing import Optional -from exo.shared.utils.fs import StrPath, ensure_parent_directory_exists +from exo.utils.fs import StrPath, ensure_parent_directory_exists # open in read-write mode, creates file if it doesn't exist already, # closes this file descriptor in any children processes (prevents FD leaking), diff --git a/src/exo/shared/ipc/pipe_duplex.py b/src/exo/shared/ipc/pipe_duplex.py index 0f1f3178..caea9922 100644 --- a/src/exo/shared/ipc/pipe_duplex.py +++ b/src/exo/shared/ipc/pipe_duplex.py @@ -33,7 +33,7 @@ from typing import Callable from cobs 
import cobs # pyright: ignore[reportMissingTypeStubs] from pytest import LogCaptureFixture -from exo.shared.utils.fs import ( +from exo.utils.fs import ( StrPath, delete_if_exists, ensure_parent_directory_exists, diff --git a/src/exo/shared/keypair.py b/src/exo/shared/keypair.py deleted file mode 100644 index a78c2cb4..00000000 --- a/src/exo/shared/keypair.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import hashlib -import logging -import os -from pathlib import Path -from typing import final - -import base58 -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ed25519 -from filelock import FileLock - -from exo.shared.constants import EXO_NODE_ID_KEYPAIR - - -@final -class PeerId: - """ - A libp2p peer identifier derived from a cryptographic public key. - Compatible with py-libp2p's PeerID interface. - """ - - def __init__(self, peer_id_bytes: bytes) -> None: - self._bytes = peer_id_bytes - - @staticmethod - def from_bytes(data: bytes) -> "PeerId": - """Create PeerId from raw bytes.""" - return PeerId(data) - - @staticmethod - def from_public_key(public_key_bytes: bytes) -> "PeerId": - """Create PeerId from a public key by hashing it.""" - # For Ed25519 keys, libp2p uses the identity hash (no hashing) for keys <= 42 bytes - # Since Ed25519 public keys are 32 bytes, we use identity hash - if len(public_key_bytes) <= 42: - return PeerId(public_key_bytes) - else: - # For larger keys, use SHA-256 - hash_digest = hashlib.sha256(public_key_bytes).digest() - return PeerId(hash_digest) - - def to_bytes(self) -> bytes: - """Return the raw bytes of this PeerId.""" - return self._bytes - - def to_base58(self) -> str: - """Return the base58-encoded string representation.""" - return base58.b58encode(self._bytes).decode("ascii") - - def __str__(self) -> str: - """Return the base58-encoded string representation.""" - return self.to_base58() - - def __repr__(self) -> str: - """Return debug representation.""" - return f"PeerId('{self.to_base58()}')" - - def __eq__(self, other: object) -> bool: - """Check equality with another PeerId.""" - if not isinstance(other, PeerId): - return False - return self._bytes == other._bytes - - def __hash__(self) -> int: - """Make PeerId hashable.""" - return hash(self._bytes) - - -@final -class Keypair: - """ - A py-libp2p compatible keypair implementation. - Provides the same interface as py-libp2p's KeyPair. - """ - - def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None: - self._private_key = private_key - self._public_key = private_key.public_key() - - @staticmethod - def generate_ed25519() -> "Keypair": - """Generate a new Ed25519 keypair.""" - private_key = ed25519.Ed25519PrivateKey.generate() - return Keypair(private_key) - - @staticmethod - def from_protobuf_encoding(data: bytes) -> "Keypair": - """ - Deserialize a keypair from libp2p protobuf encoding. - Compatible with py-libp2p's serialization format. 
- """ - if len(data) < 2: - raise ValueError("Invalid protobuf data: too short") - - # Simple protobuf parsing for our specific use case - # We expect: field 1 (type) as varint, field 2 (data) as bytes - offset = 0 - - # Parse type field (field tag 1, wire type 0 = varint) - if data[offset] != 0x08: # field 1, varint - raise ValueError("Expected type field") - offset += 1 - - key_type = data[offset] - offset += 1 - - if key_type != 1: # Ed25519 - raise ValueError(f"Unsupported key type: {key_type}") - - # Parse data field (field tag 2, wire type 2 = length-delimited) - if offset >= len(data) or data[offset] != 0x12: # field 2, bytes - raise ValueError("Expected data field") - offset += 1 - - # Parse length - data_length = data[offset] - offset += 1 - - if data_length not in (32, 64): - raise ValueError(f"Invalid Ed25519 private key length: {data_length}") - - if offset + data_length > len(data): - raise ValueError("Truncated private key data") - - key_data = data[offset : offset + data_length] - - try: - if data_length == 64: - # libp2p format: 32 bytes private key seed + 32 bytes public key - private_key_seed = key_data[:32] - private_key = ed25519.Ed25519PrivateKey.from_private_bytes( - private_key_seed - ) - else: - # Raw 32-byte private key - private_key = ed25519.Ed25519PrivateKey.from_private_bytes(key_data) - - return Keypair(private_key) - except Exception as e: - raise ValueError(f"Invalid Ed25519 private key: {e}") from e - - def to_protobuf_encoding(self) -> bytes: - """ - Serialize this keypair to libp2p protobuf encoding. - Compatible with py-libp2p's serialization format. - """ - private_key_bytes = self._private_key.private_bytes( - encoding=serialization.Encoding.Raw, - format=serialization.PrivateFormat.Raw, - encryption_algorithm=serialization.NoEncryption(), - ) - - public_key_bytes = self._public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - - # libp2p Ed25519 format: private key seed (32) + public key (32) - combined_key_data = private_key_bytes + public_key_bytes - - # Build protobuf manually for our simple case - # Field 1 (type): tag=0x08, value=1 (Ed25519) - # Field 2 (data): tag=0x12, length=64, data=combined_key_data - result = bytearray() - result.extend([0x08, 0x01]) # field 1: type = 1 (Ed25519) - result.extend([0x12, 0x40]) # field 2: length = 64 bytes - result.extend(combined_key_data) - - return bytes(result) - - def to_peer_id(self) -> PeerId: - """Generate a PeerId from this keypair's public key.""" - public_key_bytes = self._public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - return PeerId.from_public_key(public_key_bytes) - - def sign(self, data: bytes) -> bytes: - """Sign data with this keypair's private key.""" - return self._private_key.sign(data) - - def verify(self, data: bytes, signature: bytes) -> bool: - """Verify a signature against data using this keypair's public key.""" - try: - self._public_key.verify(signature, data) - return True - except Exception: - return False - - @property - def public_key_bytes(self) -> bytes: - """Get the raw public key bytes.""" - return self._public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - - @property - def private_key_bytes(self) -> bytes: - """Get the raw private key bytes.""" - return self._private_key.private_bytes( - encoding=serialization.Encoding.Raw, - format=serialization.PrivateFormat.Raw, - 
encryption_algorithm=serialization.NoEncryption(), - ) - - # py-libp2p compatibility properties - @property - def private_key(self) -> ed25519.Ed25519PrivateKey: - """Access to the underlying private key for py-libp2p compatibility.""" - return self._private_key - - @property - def public_key(self) -> ed25519.Ed25519PublicKey: - """Access to the underlying public key for py-libp2p compatibility.""" - return self._public_key - - -def get_node_id_keypair( - path: str | bytes | os.PathLike[str] | os.PathLike[bytes] = EXO_NODE_ID_KEYPAIR, -) -> Keypair: - """ - Obtains the :class:`Keypair` associated with this node-ID. - Obtain the :class:`PeerId` by from it. - """ - - def lock_path(path: str | bytes | os.PathLike[str] | os.PathLike[bytes]) -> Path: - return Path(str(path) + ".lock") - - # operate with cross-process lock to avoid race conditions - with FileLock(lock_path(path)): - with open(path, "a+b") as f: # opens in append-mode => starts at EOF - # if non-zero EOF, then file exists => use to get node-ID - if f.tell() != 0: - f.seek(0) # go to start & read protobuf-encoded bytes - protobuf_encoded = f.read() - - try: # if decoded successfully, save & return - return Keypair.from_protobuf_encoding(protobuf_encoded) - except ValueError as e: # on runtime error, assume corrupt file - logging.warning( - f"Encountered error when trying to get keypair: {e}" - ) - - # if no valid credentials, create new ones and persist - with open(path, "w+b") as f: - keypair = Keypair.generate_ed25519() - f.write(keypair.to_protobuf_encoding()) - return keypair diff --git a/src/exo/shared/logging.py b/src/exo/shared/logging.py index 2798ffbe..60705bf6 100644 --- a/src/exo/shared/logging.py +++ b/src/exo/shared/logging.py @@ -1,32 +1,13 @@ -from __future__ import annotations - import sys -from logging import Logger from pathlib import Path -import loguru from loguru import logger -from exo.shared.constants import EXO_TEST_LOG - - -def is_user_facing(record: loguru.Record) -> bool: - return ("user_facing" in record["extra"]) and record["extra"]["user_facing"] - def logger_setup(log_file: Path, verbosity: int = 0): """Set up logging for this process - formatting, file handles, verbosity and output""" logger.remove() if verbosity == 0: - _ = logger.add( # type: ignore - sys.__stderr__, # type: ignore - format="[ {time:hh:mm:ss.SSSSA} | {level: <8}] {message}", - level="INFO", - colorize=True, - enqueue=True, - filter=is_user_facing, - ) - elif verbosity == 1: _ = logger.add( # type: ignore sys.__stderr__, # type: ignore format="[ {time:hh:mm:ss.SSSSA} | {level: <8}] {message}", @@ -40,11 +21,12 @@ def logger_setup(log_file: Path, verbosity: int = 0): format="[ {time:HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} ] {message}", level="DEBUG", colorize=True, + enqueue=True, ) _ = logger.add( log_file, format="[ {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} ] {message}", - level="DEBUG", + level="INFO", enqueue=True, ) @@ -52,10 +34,3 @@ def logger_setup(log_file: Path, verbosity: int = 0): def logger_cleanup(): """Flush all queues before shutting down so any in-flight logs are written to disk""" logger.complete() - - -def logger_test_install(py_logger: Logger): - """Installs a default python logger into the Loguru environment by capturing all its handlers - intended to be used for pytest compatibility, not within the main codebase""" - logger_setup(EXO_TEST_LOG, 3) - for handler in py_logger.handlers: - logger.add(handler) diff --git a/src/exo/shared/models/model_cards.py 
b/src/exo/shared/models/model_cards.py index 4b47559a..52667413 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -1,11 +1,11 @@ from typing import List -from pydantic import BaseModel - -from exo.shared.types.models import ModelMetadata +from exo.shared.types.memory import Memory +from exo.shared.types.models import ModelId, ModelMetadata +from exo.utils.pydantic_ext import CamelCaseModel -class ModelCard(BaseModel): +class ModelCard(CamelCaseModel): short_id: str model_id: str name: str @@ -23,9 +23,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-V3-0324-4bit", + model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"), pretty_name="DeepSeek V3 0324 (4-bit)", - storage_size_kilobytes=409706307, + storage_size=Memory.from_kb(409706307), n_layers=61, ), ), @@ -36,9 +36,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-v3-0324-8bit", + model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"), pretty_name="DeepSeek V3 0324 (8-bit)", - storage_size_kilobytes=754706307, + storage_size=Memory.from_kb(754706307), n_layers=61, ), ), @@ -49,9 +49,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-V3.1-8bit", + model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"), pretty_name="DeepSeek V3.1 (8-bit)", - storage_size_kilobytes=754706307, + storage_size=Memory.from_kb(754706307), n_layers=61, ), ), @@ -62,9 +62,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-V3.1-4bit", + model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"), pretty_name="DeepSeek V3.1 (4-bit)", - storage_size_kilobytes=754706307 // 2, # TODO !!!!! + storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!! 
n_layers=61, ), ), @@ -76,9 +76,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-R1-0528-4bit", + model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"), pretty_name="DeepSeek R1 671B (4-bit)", - storage_size_kilobytes=409706307, + storage_size=Memory.from_kb(409706307), n_layers=61, ), ), @@ -89,9 +89,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/DeepSeek-R1-0528-8bit", + model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"), pretty_name="DeepSeek R1 671B (8-bit)", - storage_size_kilobytes=754998771712 // 1024, + storage_size=Memory.from_bytes(754998771712), n_layers=61, ), ), @@ -103,9 +103,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", + model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"), pretty_name="Llama 3.1 8B", - storage_size_kilobytes=4411528, + storage_size=Memory.from_kb(4411528), n_layers=32, ), ), @@ -116,9 +116,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", + model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"), pretty_name="Llama 3.1 70B", - storage_size_kilobytes=38758160, + storage_size=Memory.from_kb(38758160), n_layers=80, ), ), @@ -130,9 +130,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", + model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"), pretty_name="Llama 3.2 1B", - storage_size_kilobytes=678948, + storage_size=Memory.from_kb(678948), n_layers=16, ), ), @@ -143,9 +143,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", + model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"), pretty_name="Llama 3.2 3B", - storage_size_kilobytes=1765062, + storage_size=Memory.from_kb(1765062), n_layers=28, ), ), @@ -157,9 +157,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Llama-3.3-70B-Instruct-4bit", + model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"), pretty_name="Llama 3.3 70B", - storage_size_kilobytes=38758160, + storage_size=Memory.from_kb(38758160), n_layers=80, ), ), @@ -171,9 +171,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Phi-3-mini-128k-instruct-4bit", + model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"), pretty_name="Phi 3 Mini 128k", - storage_size_kilobytes=2099262, + storage_size=Memory.from_kb(2099262), n_layers=32, ), ), @@ -184,9 +184,9 @@ MODEL_CARDS: dict[str, 
ModelCard] = { description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Phi-3-mini-128k-instruct-4bit", + model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"), pretty_name="Phi 3 Mini 128k", - storage_size_kilobytes=2099262, + storage_size=Memory.from_kb(2099262), n_layers=32, ), ), @@ -198,9 +198,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Qwen3-0.6B-4bit", + model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"), pretty_name="Qwen3 0.6B", - storage_size_kilobytes=327512, + storage_size=Memory.from_kb(327512), n_layers=28, ), ), @@ -211,9 +211,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/Qwen3-30B-A3B-4bit", + model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"), pretty_name="Qwen3 30B (Active 3B)", - storage_size_kilobytes=16772092, + storage_size=Memory.from_kb(16772092), n_layers=48, ), ), @@ -225,9 +225,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/granite-3.3-2b-instruct-fp16", + model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"), pretty_name="Granite 3.3 2B", - storage_size_kilobytes=4948320, + storage_size=Memory.from_kb(4948320), n_layers=40, ), ), @@ -238,9 +238,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/granite-3.3-8b-instruct-fp16", + model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"), pretty_name="Granite 3.3 8B", - storage_size_kilobytes=15958720, + storage_size=Memory.from_kb(15958720), n_layers=40, ), ), @@ -252,9 +252,9 @@ MODEL_CARDS: dict[str, ModelCard] = { description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. 
""", tags=[], metadata=ModelMetadata( - model_id="mlx-community/SmolLM-135M-4bit", + model_id=ModelId("mlx-community/SmolLM-135M-4bit"), pretty_name="Smol LM 135M", - storage_size_kilobytes=73940, + storage_size=Memory.from_kb(73940), n_layers=30, ), ), diff --git a/src/exo/shared/models/model_meta.py b/src/exo/shared/models/model_meta.py index de54536f..9ed1f151 100644 --- a/src/exo/shared/models/model_meta.py +++ b/src/exo/shared/models/model_meta.py @@ -6,7 +6,8 @@ from huggingface_hub import model_info from loguru import logger from pydantic import BaseModel, Field -from exo.shared.types.models import ModelMetadata +from exo.shared.types.memory import Memory +from exo.shared.types.models import ModelId, ModelMetadata from exo.worker.download.download_utils import ( ModelSafetensorsIndex, download_file_with_retry, @@ -65,7 +66,7 @@ async def get_config_data(model_id: str) -> ConfigData: return ConfigData.model_validate_json(await f.read()) -async def get_safetensors_size(model_id: str) -> int: +async def get_safetensors_size(model_id: str) -> Memory: """Gets model size from safetensors index or falls back to HF API.""" target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--") await aios.makedirs(target_dir, exist_ok=True) @@ -83,12 +84,12 @@ async def get_safetensors_size(model_id: str) -> int: metadata = index_data.metadata if metadata is not None: - return metadata.total_size + return Memory.from_bytes(metadata.total_size) info = model_info(model_id) if info.safetensors is None: raise ValueError(f"No safetensors info found for {model_id}") - return info.safetensors.total + return Memory.from_bytes(info.safetensors.total) _model_meta_cache: Dict[str, ModelMetadata] = {} @@ -109,8 +110,8 @@ async def _get_model_meta(model_id: str) -> ModelMetadata: mem_size_bytes = await get_safetensors_size(model_id) return ModelMetadata( - model_id=model_id, + model_id=ModelId(model_id), pretty_name=model_id, - storage_size_kilobytes=mem_size_bytes // 1024, + storage_size=mem_size_bytes, n_layers=num_layers, ) diff --git a/src/exo/shared/tests/test_election.py b/src/exo/shared/tests/test_election.py new file mode 100644 index 00000000..1c04e5c1 --- /dev/null +++ b/src/exo/shared/tests/test_election.py @@ -0,0 +1,313 @@ +import pytest +from anyio import create_task_group, fail_after, move_on_after + +from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType +from exo.shared.election import Election, ElectionMessage, ElectionResult +from exo.shared.types.common import NodeId +from exo.utils.channels import channel + +# ======= # +# Helpers # +# ======= # + + +def em(clock: int, seniority: int, node_id: str) -> ElectionMessage: + return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id)) + + +@pytest.fixture +def fast_timeout(monkeypatch: pytest.MonkeyPatch): + # Keep campaigns fast; user explicitly allows tests to shorten the timeout. + import exo.shared.election as election_mod + + monkeypatch.setattr(election_mod, "ELECTION_TIMEOUT", 0.05, raising=True) + yield + + +# ======================================= # +# TESTS # +# ======================================= # + + +@pytest.mark.anyio +async def test_single_round_broadcasts_and_updates_seniority_on_self_win( + fast_timeout: None, +) -> None: + """ + Start a round by injecting an ElectionMessage with higher clock. + With only our node effectively 'winning', we should broadcast once and update seniority. 
+ """ + # Outbound election messages from the Election (we'll observe these) + em_out_tx, em_out_rx = channel[ElectionMessage]() + # Inbound election messages to the Election (we'll inject these) + em_in_tx, em_in_rx = channel[ElectionMessage]() + # Election results produced by the Election (we'll observe these) + er_tx, er_rx = channel[ElectionResult]() + # Connection messages (unused in this test but required by ctor) + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("B"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + # Trigger new round at clock=1 (peer announces it) + await em_in_tx.send(em(clock=1, seniority=0, node_id="A")) + + # Expect our broadcast back to the peer side for this round only + while True: + got = await em_out_rx.receive() + if got.clock == 1 and got.node_id == NodeId("B"): + break + + # Wait for the round to finish and produce an ElectionResult + result = await er_rx.receive() + assert result.node_id == NodeId("B") + # We spawned as master; electing ourselves again is not "new master". + assert result.is_new_master is False + + # Close inbound streams to end the receivers (and run()) + await em_in_tx.aclose() + await cm_tx.aclose() + + # We should have updated seniority to 2 (A + B). + assert election.seniority == 2 + + +@pytest.mark.anyio +async def test_peer_with_higher_seniority_wins_and_we_switch_master( + fast_timeout: None, +) -> None: + """ + If a peer with clearly higher seniority participates in the round, they should win. + We should broadcast our status exactly once for this round, then switch master. + """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("ME"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Start round with peer's message (higher seniority) + await em_in_tx.send(em(clock=1, seniority=10, node_id="PEER")) + + # We should still broadcast our status exactly once for this round + while True: + got = await em_out_rx.receive() + if got.clock == 1: + assert got.seniority == 0 + break + + # After the timeout, election result should report the peer as master + result = await er_rx.receive() + assert result.node_id == NodeId("PEER") + assert result.is_new_master is True + + await em_in_tx.aclose() + await cm_tx.aclose() + + # We lost → seniority unchanged + assert election.seniority == 0 + + +@pytest.mark.anyio +async def test_ignores_older_messages(fast_timeout: None) -> None: + """ + Messages with a lower clock than the current round are ignored by the receiver. + Expect exactly one broadcast for the higher clock round. 
+ """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, _er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("ME"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Newer round arrives first -> triggers campaign at clock=2 + await em_in_tx.send(em(clock=2, seniority=0, node_id="A")) + while True: + first = await em_out_rx.receive() + if first.clock == 2: + break + + # Older message (clock=1) must be ignored (no second broadcast) + await em_in_tx.send(em(clock=1, seniority=999, node_id="B")) + + got_second = False + with move_on_after(0.2): + _ = await em_out_rx.receive() + got_second = True + assert not got_second, "Should not receive a broadcast for an older round" + + await em_in_tx.aclose() + await cm_tx.aclose() + + # Not asserting on the result; focus is on ignore behavior. + + +@pytest.mark.anyio +async def test_two_rounds_emit_two_broadcasts_and_increment_clock( + fast_timeout: None, +) -> None: + """ + Two successive rounds → two broadcasts. Second round triggered by a higher-clock message. + """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, _er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("ME"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Round 1 at clock=1 + await em_in_tx.send(em(clock=1, seniority=0, node_id="X")) + while True: + m1 = await em_out_rx.receive() + if m1.clock == 1: + break + + # Round 2 at clock=2 + await em_in_tx.send(em(clock=2, seniority=0, node_id="Y")) + while True: + m2 = await em_out_rx.receive() + if m2.clock == 2: + break + + await em_in_tx.aclose() + await cm_tx.aclose() + + # Not asserting on who won; just that both rounds were broadcast. + + +@pytest.mark.anyio +async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -> None: + """ + When we win against two peers in the same round, our seniority becomes + max(existing, number_of_candidates). With existing=0: expect 3 (us + A + B). 
+ """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("ME"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Start round at clock=7 with two peer participants + await em_in_tx.send(em(clock=7, seniority=0, node_id="A")) + await em_in_tx.send(em(clock=7, seniority=0, node_id="B")) + + # We should see exactly one broadcast from us for this round + while True: + got = await em_out_rx.receive() + if got.clock == 7 and got.node_id == NodeId("ME"): + break + + # Wait for the election to finish so seniority updates + _ = await er_rx.receive() + + await em_in_tx.aclose() + await cm_tx.aclose() + + # We + A + B = 3 → new seniority expected to be 3 + assert election.seniority == 3 + + +@pytest.mark.anyio +async def test_connection_message_triggers_new_round_broadcast( + fast_timeout: None, +) -> None: + """ + A connection message increments the clock and starts a new campaign. + We should observe a broadcast at the incremented clock. + """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, _er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + + election = Election( + node_id=NodeId("ME"), + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + is_candidate=True, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Send any connection message object; we close quickly to cancel before result creation + await cm_tx.send( + ConnectionMessage( + node_id=NodeId(), + connection_type=ConnectionMessageType.Connected, + remote_ipv4="", + remote_tcp_port=0, + ) + ) + + # Expect a broadcast for the new round at clock=1 + while True: + got = await em_out_rx.receive() + if got.clock == 1 and got.node_id == NodeId("ME"): + break + + # Close promptly to avoid waiting for campaign completion + await em_in_tx.aclose() + await cm_tx.aclose() + + # After cancellation (before election finishes), no seniority changes asserted here. 
diff --git a/src/exo/shared/tests/test_flock_mutex.py b/src/exo/shared/tests/test_flock_mutex.py index 42d68753..0dc1be4f 100644 --- a/src/exo/shared/tests/test_flock_mutex.py +++ b/src/exo/shared/tests/test_flock_mutex.py @@ -1,7 +1,7 @@ import pytest from exo.shared.ipc.file_mutex.flock_mutex import FlockMutex, LockType -from exo.shared.utils.fs import delete_if_exists, make_temp_path +from exo.utils.fs import delete_if_exists, make_temp_path def test_lock_held(): diff --git a/src/exo/shared/tests/test_node_id_persistence.py b/src/exo/shared/tests/test_node_id_persistence.py index 46a81d55..4633ab90 100644 --- a/src/exo/shared/tests/test_node_id_persistence.py +++ b/src/exo/shared/tests/test_node_id_persistence.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import contextlib import logging import multiprocessing @@ -13,8 +11,8 @@ from typing import Optional from pytest import LogCaptureFixture +from exo.routing.router import get_node_id_keypair from exo.shared.constants import EXO_NODE_ID_KEYPAIR -from exo.shared.keypair import get_node_id_keypair NUM_CONCURRENT_PROCS = 10 diff --git a/src/exo/shared/tests/test_sqlite_connector.py b/src/exo/shared/tests/test_sqlite_connector.py deleted file mode 100644 index 8917e9ce..00000000 --- a/src/exo/shared/tests/test_sqlite_connector.py +++ /dev/null @@ -1,612 +0,0 @@ -import asyncio -import json -import tempfile -from pathlib import Path -from typing import Any, Generator, cast -from uuid import uuid4 - -import pytest -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession - -from exo.shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig -from exo.shared.types.common import CommandId, NodeId -from exo.shared.types.events import ChunkGenerated -from exo.shared.types.events.chunks import ChunkType, TokenChunk - -# Type ignore comment for all protected member access in this test file -# pyright: reportPrivateUsage=false - - -def _load_json_data(raw_data: str) -> dict[str, Any]: - """Helper function to load JSON data with proper typing.""" - return cast(dict[str, Any], json.loads(raw_data)) - - -@pytest.fixture -def temp_db_path() -> Generator[Path, None, None]: - """Create a temporary database file for testing.""" - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: - yield Path(f.name) - # Cleanup - Path(f.name).unlink(missing_ok=True) - - -@pytest.fixture -def sample_node_id() -> NodeId: - """Create a sample NodeId for testing.""" - return NodeId() - - -class TestAsyncSQLiteEventStorage: - """Test suite for AsyncSQLiteEventStorage focused on storage functionality.""" - - @pytest.mark.asyncio - async def test_initialization_creates_tables(self, temp_db_path: Path) -> None: - """Test that database initialization creates the events table.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Verify table exists by querying directly - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text( - "SELECT name FROM sqlite_master WHERE type='table' AND name='events'" - ) - ) - tables = result.fetchall() - assert len(tables) == 1 - assert tables[0][0] == "events" - - await storage.close() - - @pytest.mark.asyncio - async def test_start_twice_raises_error(self, temp_db_path: 
Path) -> None: - """Test that starting storage twice raises an error.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - with pytest.raises(RuntimeError, match="Storage already started"): - await storage.start() - - await storage.close() - - @pytest.mark.asyncio - async def test_direct_database_operations( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test direct database operations without event parsing.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Insert test data directly - test_data = { - "event_type": "test_event", - "test_field": "test_value", - "number": 42, - } - - async with AsyncSession(storage._engine) as session: - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": sample_node_id, - "event_type": "test_event", - "event_id": str(uuid4()), - "event_data": json.dumps(test_data), - }, - ) - await session.commit() - - # Query data back - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text("SELECT rowid, origin, event_data FROM events ORDER BY rowid") - ) - rows = result.fetchall() - - assert len(rows) == 1 - assert rows[0][0] == 1 # rowid - assert rows[0][1] == sample_node_id # origin - raw_json = cast(str, rows[0][2]) - retrieved_data = _load_json_data(raw_json) - assert retrieved_data == test_data - - await storage.close() - - @pytest.mark.asyncio - async def test_rowid_auto_increment( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test that rowid auto-increments correctly.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Insert multiple records - test_records = [ - {"event_type": "test_event_1", "data": "first"}, - {"event_type": "test_event_2", "data": "second"}, - {"event_type": "test_event_3", "data": "third"}, - ] - - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - for record in test_records: - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": sample_node_id, - "event_type": record["event_type"], - "event_id": str(uuid4()), - "event_data": json.dumps(record), - }, - ) - await session.commit() - - # Query back and verify rowid sequence - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text("SELECT rowid, event_data FROM events ORDER BY rowid") - ) - rows = result.fetchall() - - assert len(rows) == 3 - for i, row in enumerate(rows): - assert row[0] == i + 1 # rowid starts at 1 - raw_json = cast(str, row[1]) - retrieved_data = 
_load_json_data(raw_json) - assert retrieved_data == test_records[i] - - await storage.close() - - @pytest.mark.asyncio - async def test_get_last_idx( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test that rowid returns correctly from db.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Insert multiple records - test_records = [ - {"event_type": "test_event_1", "data": "first"}, - {"event_type": "test_event_2", "data": "second"}, - {"event_type": "test_event_3", "data": "third"}, - ] - - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - for record in test_records: - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": sample_node_id, - "event_type": record["event_type"], - "event_id": str(uuid4()), - "event_data": json.dumps(record), - }, - ) - await session.commit() - - last_idx = await storage.get_last_idx() - assert last_idx == 3 - - await storage.close() - - @pytest.mark.asyncio - async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None: - """Test rowid sequence across multiple origins.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - origin1 = NodeId() - origin2 = NodeId() - - # Insert interleaved records from different origins - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - # Origin 1 - record 1 - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": origin1, - "event_type": "event_1", - "event_id": str(uuid4()), - "event_data": json.dumps({"from": "origin1", "seq": 1}), - }, - ) - # Origin 2 - record 2 - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": origin2, - "event_type": "event_2", - "event_id": str(uuid4()), - "event_data": json.dumps({"from": "origin2", "seq": 2}), - }, - ) - # Origin 1 - record 3 - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": origin1, - "event_type": "event_3", - "event_id": str(uuid4()), - "event_data": json.dumps({"from": "origin1", "seq": 3}), - }, - ) - await session.commit() - - # Verify sequential rowid regardless of origin - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text("SELECT rowid, origin, event_data FROM events ORDER BY rowid") - ) - rows = result.fetchall() - - assert len(rows) == 3 - assert rows[0][0] == 1 # First rowid - assert rows[1][0] == 2 # Second rowid - assert rows[2][0] == 3 # Third rowid - - # Verify data integrity - raw_json1 = cast(str, rows[0][2]) - raw_json2 = cast(str, rows[1][2]) - raw_json3 = cast(str, rows[2][2]) - data1 = 
_load_json_data(raw_json1) - data2 = _load_json_data(raw_json2) - data3 = _load_json_data(raw_json3) - - assert data1["from"] == "origin1" and data1["seq"] == 1 - assert data2["from"] == "origin2" and data2["seq"] == 2 - assert data3["from"] == "origin1" and data3["seq"] == 3 - - await storage.close() - - @pytest.mark.asyncio - async def test_query_events_since_index( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test querying events after a specific rowid.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Insert 10 test records - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - for i in range(10): - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": sample_node_id, - "event_type": f"event_{i}", - "event_id": str(uuid4()), - "event_data": json.dumps({"index": i}), - }, - ) - await session.commit() - - # Query events after index 5 - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text( - "SELECT rowid, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid" - ), - {"last_idx": 5}, - ) - rows = result.fetchall() - - assert len(rows) == 5 # Should get records 6-10 - for i, row in enumerate(rows): - assert row[0] == i + 6 # rowid 6, 7, 8, 9, 10 - raw_json = cast(str, row[1]) - data = _load_json_data(raw_json) - assert data["index"] == i + 5 # index 5, 6, 7, 8, 9 - - await storage.close() - - @pytest.mark.asyncio - async def test_empty_query(self, temp_db_path: Path) -> None: - """Test querying when no events exist.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text( - "SELECT rowid, origin, event_data FROM events WHERE rowid > :last_idx ORDER BY rowid" - ), - {"last_idx": 0}, - ) - rows = result.fetchall() - - assert len(rows) == 0 - - await storage.close() - - @pytest.mark.asyncio - async def test_operations_after_close_raise_error(self, temp_db_path: Path) -> None: - """Test that operations after close work properly.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - await storage.close() - - # These should not raise errors since we're not using the public API - assert storage._closed is True - assert storage._engine is not None # Engine should still exist but be disposed - - @pytest.mark.asyncio - async def test_multiple_close_calls_safe(self, temp_db_path: Path) -> None: - """Test that multiple close calls are safe.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - 
batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - await storage.close() - await storage.close() # Should not raise an error - - @pytest.mark.asyncio - async def test_json_data_types( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test that various JSON data types are handled correctly.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Test various JSON data types - test_data = { - "string": "test string", - "number": 42, - "float": 3.14, - "boolean": True, - "null": None, - "array": [1, 2, 3, "four"], - "object": {"nested": "value", "deep": {"deeper": "nested"}}, - "unicode": "测试 🚀", - } - - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": sample_node_id, - "event_type": "complex_event", - "event_id": str(uuid4()), - "event_data": json.dumps(test_data), - }, - ) - await session.commit() - - # Query back and verify data integrity - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text("SELECT event_data FROM events WHERE event_type = :event_type"), - {"event_type": "complex_event"}, - ) - rows = result.fetchall() - - assert len(rows) == 1 - raw_json = cast(str, rows[0][0]) - retrieved_data = _load_json_data(raw_json) - assert retrieved_data == test_data - - await storage.close() - - @pytest.mark.asyncio - async def test_concurrent_inserts(self, temp_db_path: Path) -> None: - """Test concurrent inserts maintain rowid ordering.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - async def insert_batch(origin_id: str, batch_id: int, count: int) -> None: - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - for i in range(count): - await session.execute( - text( - "INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)" - ), - { - "origin": origin_id, - "event_type": f"batch_{batch_id}_event_{i}", - "event_id": str(uuid4()), - "event_data": json.dumps({"batch": batch_id, "item": i}), - }, - ) - await session.commit() - - # Run multiple concurrent insert batches - origin1 = str(uuid4()) - origin2 = str(uuid4()) - origin3 = str(uuid4()) - - await asyncio.gather( - insert_batch(origin1, 1, 5), - insert_batch(origin2, 2, 5), - insert_batch(origin3, 3, 5), - ) - - # Verify all records were inserted and rowid is sequential - assert storage._engine is not None - async with AsyncSession(storage._engine) as session: - result = await session.execute( - text("SELECT rowid, origin, event_data FROM events ORDER BY rowid") - ) - rows = result.fetchall() - - assert len(rows) == 15 # 3 batches * 5 records each - - # Verify rowid sequence is maintained 
- for i, row in enumerate(rows): - assert row[0] == i + 1 # rowid should be sequential - - await storage.close() - - @pytest.mark.asyncio - async def test_chunk_generated_event_serialization( - self, temp_db_path: Path, sample_node_id: NodeId - ) -> None: - """Test that ChunkGenerated event with nested types can be serialized and deserialized correctly.""" - default_config = EventLogConfig() - storage = AsyncSQLiteEventStorage( - db_path=temp_db_path, - batch_size=default_config.batch_size, - batch_timeout_ms=default_config.batch_timeout_ms, - debounce_ms=default_config.debounce_ms, - max_age_ms=default_config.max_age_ms, - ) - await storage.start() - - # Create a ChunkGenerated event with nested TokenChunk - command_id = CommandId() - token_chunk = TokenChunk( - text="Hello, world!", - token_id=42, - finish_reason="stop", - chunk_type=ChunkType.token, - command_id=command_id, - idx=0, - model="test-model", - ) - - chunk_generated_event = ChunkGenerated(command_id=command_id, chunk=token_chunk) - - # Store the event using the storage API - await storage.append_events([chunk_generated_event], sample_node_id) - - # Wait for batch to be written - await asyncio.sleep(0.5) - - # Retrieve the event - events = await storage.get_events_since(0) - - # Verify we got the event back - assert len(events) == 1 - retrieved_event_wrapper = events[0] - assert retrieved_event_wrapper.origin == sample_node_id - - # Verify the event was deserialized correctly - retrieved_event = retrieved_event_wrapper.event - assert isinstance(retrieved_event, ChunkGenerated) - assert retrieved_event.command_id == command_id - - # Verify the nested chunk was deserialized correctly - retrieved_chunk = retrieved_event.chunk - assert isinstance(retrieved_chunk, TokenChunk) - assert retrieved_chunk.chunk_type == ChunkType.token - assert retrieved_chunk.command_id == command_id - assert retrieved_chunk.idx == 0 - assert retrieved_chunk.model == "test-model" - - # Verify the chunk data - assert retrieved_chunk.text == "Hello, world!" 
- assert retrieved_chunk.token_id == 42 - assert retrieved_chunk.finish_reason == "stop" - - await storage.close() diff --git a/src/exo/shared/tests/test_state_serialization.py b/src/exo/shared/tests/test_state_serialization.py index 2497c437..5935d444 100644 --- a/src/exo/shared/tests/test_state_serialization.py +++ b/src/exo/shared/tests/test_state_serialization.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from exo.shared.types.common import NodeId from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.state import State @@ -16,13 +14,11 @@ def test_state_serialization_roundtrip() -> None: connection = Connection( local_node_id=node_a, send_back_node_id=node_b, - local_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10000"), send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"), ) state = State() state.topology.add_connection(connection) - state.topology.master_node_id = node_a json_repr = state.model_dump_json() restored_state = State.model_validate_json(json_repr) diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index a3825a27..5be5af86 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -5,38 +5,33 @@ import rustworkx as rx from pydantic import BaseModel, ConfigDict from exo.shared.types.common import NodeId -from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile -from exo.shared.types.topology import Connection, Node, TopologyProto +from exo.shared.types.topology import Connection, NodeInfo class TopologySnapshot(BaseModel): - nodes: list[Node] + nodes: list[NodeInfo] connections: list[Connection] - master_node_id: NodeId | None = None model_config = ConfigDict(frozen=True, extra="forbid", strict=True) -class Topology(TopologyProto): +class Topology: def __init__(self) -> None: - self._graph: rx.PyDiGraph[Node, Connection] = rx.PyDiGraph() + self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph() self._node_id_to_rx_id_map: dict[NodeId, int] = dict() self._rx_id_to_node_id_map: dict[int, NodeId] = dict() self._edge_id_to_rx_id_map: dict[Connection, int] = dict() - self.master_node_id: NodeId | None = None def to_snapshot(self) -> TopologySnapshot: return TopologySnapshot( nodes=list(self.list_nodes()), connections=list(self.list_connections()), - master_node_id=self.master_node_id, ) @classmethod def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology": topology = cls() - topology.master_node_id = snapshot.master_node_id for node in snapshot.nodes: with contextlib.suppress(ValueError): @@ -47,16 +42,13 @@ class Topology(TopologyProto): return topology - def add_node(self, node: Node) -> None: + def add_node(self, node: NodeInfo) -> None: if node.node_id in self._node_id_to_rx_id_map: return rx_id = self._graph.add_node(node) self._node_id_to_rx_id_map[node.node_id] = rx_id self._rx_id_to_node_id_map[rx_id] = node.node_id - def set_master_node_id(self, node_id: NodeId) -> None: - self.master_node_id = node_id - def contains_node(self, node_id: NodeId) -> bool: return node_id in self._node_id_to_rx_id_map @@ -68,9 +60,9 @@ class Topology(TopologyProto): connection: Connection, ) -> None: if connection.local_node_id not in self._node_id_to_rx_id_map: - self.add_node(Node(node_id=connection.local_node_id)) + self.add_node(NodeInfo(node_id=connection.local_node_id)) if connection.send_back_node_id not in self._node_id_to_rx_id_map: - self.add_node(Node(node_id=connection.send_back_node_id)) + 
self.add_node(NodeInfo(node_id=connection.send_back_node_id)) src_id = self._node_id_to_rx_id_map[connection.local_node_id] sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id] @@ -78,12 +70,11 @@ class Topology(TopologyProto): rx_id = self._graph.add_edge(src_id, sink_id, connection) self._edge_id_to_rx_id_map[connection] = rx_id - def list_nodes(self) -> Iterable[Node]: - yield from (self._graph[i] for i in self._graph.node_indices()) + def list_nodes(self) -> Iterable[NodeInfo]: + return (self._graph[i] for i in self._graph.node_indices()) def list_connections(self) -> Iterable[Connection]: - for _, _, connection in self._graph.weighted_edge_list(): - yield connection + return (connection for _, _, connection in self._graph.weighted_edge_list()) def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: try: @@ -92,14 +83,6 @@ class Topology(TopologyProto): except KeyError: return None - def get_node_multiaddr(self, node_id: NodeId) -> Multiaddr: - for connection in self.list_connections(): - if connection.local_node_id == node_id: - return connection.local_multiaddr - if connection.send_back_node_id == node_id: - return connection.send_back_multiaddr - raise ValueError(f"Node {node_id} is not connected to any other nodes") - def update_node_profile( self, node_id: NodeId, node_profile: NodePerformanceProfile ) -> None: @@ -128,37 +111,40 @@ class Topology(TopologyProto): def remove_connection(self, connection: Connection) -> None: rx_idx = self._edge_id_to_rx_id_map[connection] - if self._is_bridge(connection): - # Determine the reference node from which reachability is calculated. - # Prefer a master node if the topology knows one; otherwise fall back to - # the local end of the connection being removed. - reference_node_id: NodeId = ( - self.master_node_id - if self.master_node_id is not None - else connection.local_node_id - ) - orphan_node_ids = self._get_orphan_node_ids(reference_node_id, connection) - for orphan_node_id in orphan_node_ids: - orphan_node_rx_id = self._node_id_to_rx_id_map[orphan_node_id] - self._graph.remove_node(orphan_node_rx_id) - del self._node_id_to_rx_id_map[orphan_node_id] - del self._rx_id_to_node_id_map[orphan_node_rx_id] - self._graph.remove_edge_from_index(rx_idx) del self._edge_id_to_rx_id_map[connection] - if rx_idx in self._rx_id_to_node_id_map: - del self._rx_id_to_node_id_map[rx_idx] - def get_cycles(self) -> list[list[Node]]: + def get_cycles(self) -> list[list[NodeInfo]]: cycle_idxs = rx.simple_cycles(self._graph) - cycles: list[list[Node]] = [] + cycles: list[list[NodeInfo]] = [] for cycle_idx in cycle_idxs: cycle = [self._graph[idx] for idx in cycle_idx] cycles.append(cycle) return cycles - def get_subgraph_from_nodes(self, nodes: list[Node]) -> "Topology": + def get_cycles_tb(self) -> list[list[NodeInfo]]: + tb_edges = [ + (u, v, conn) + for u, v, conn in self._graph.weighted_edge_list() + if conn.is_thunderbolt() + ] + + tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph() + tb_graph.add_nodes_from(self._graph.nodes()) + + for u, v, conn in tb_edges: + tb_graph.add_edge(u, v, conn) + + cycle_idxs = rx.simple_cycles(tb_graph) + cycles: list[list[NodeInfo]] = [] + for cycle_idx in cycle_idxs: + cycle = [tb_graph[idx] for idx in cycle_idx] + cycles.append(cycle) + + return cycles + + def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology": node_idxs = [node.node_id for node in nodes] rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs] topology = Topology() @@ -172,7 +158,7 @@ 
class Topology(TopologyProto): topology.add_connection(connection) return topology - def is_thunderbolt_cycle(self, cycle: list[Node]) -> bool: + def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool: node_idxs = [node.node_id for node in cycle] rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs] for rid in rx_idxs: @@ -187,49 +173,3 @@ class Topology(TopologyProto): if not has_tb: return False return True - - def _is_bridge(self, connection: Connection) -> bool: - """Check if removing this connection will orphan any nodes from the master.""" - if self.master_node_id is None: - return False - - orphan_node_ids = self._get_orphan_node_ids(self.master_node_id, connection) - return len(orphan_node_ids) > 0 - - def _get_orphan_node_ids( - self, master_node_id: NodeId, connection: Connection - ) -> list[NodeId]: - """Return node_ids that become unreachable from `master_node_id` once `connection` is removed. - - A node is considered *orphaned* if there exists **no directed path** from - the master node to that node after deleting the edge identified by - ``connection``. This definition is strictly weaker than being in a - different *strongly* connected component and more appropriate for - directed networks where information only needs to flow *outwards* from - the master. - """ - edge_idx = self._edge_id_to_rx_id_map[connection] - # Operate on a copy so the original topology remains intact while we - # compute reachability. - graph_copy: rx.PyDiGraph[Node, Connection] = self._graph.copy() - graph_copy.remove_edge_from_index(edge_idx) - - if master_node_id not in self._node_id_to_rx_id_map: - # If the provided master node isn't present we conservatively treat - # every other node as orphaned. - return list(self._node_id_to_rx_id_map.keys()) - - master_rx_id = self._node_id_to_rx_id_map[master_node_id] - - # Nodes reachable by following outgoing edges from the master. - reachable_rx_ids: set[int] = set(rx.descendants(graph_copy, master_rx_id)) - reachable_rx_ids.add(master_rx_id) - - # Every existing node index not reachable is orphaned. 
- orphan_rx_ids = set(graph_copy.node_indices()) - reachable_rx_ids - - return [ - self._rx_id_to_node_id_map[rx_id] - for rx_id in orphan_rx_ids - if rx_id in self._rx_id_to_node_id_map - ] diff --git a/src/exo/shared/types/__init__.py b/src/exo/shared/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 22870e63..f91e315e 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -133,7 +133,6 @@ class CreateInstanceResponse(BaseModel): message: str command_id: CommandId model_meta: ModelMetadata - instance_id: InstanceId class DeleteInstanceResponse(BaseModel): diff --git a/src/exo/shared/types/chunks.py b/src/exo/shared/types/chunks.py new file mode 100644 index 00000000..ec7a8295 --- /dev/null +++ b/src/exo/shared/types/chunks.py @@ -0,0 +1,35 @@ +from enum import Enum +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from exo.shared.openai_compat import FinishReason +from exo.shared.types.common import CommandId +from exo.shared.types.models import ModelId + + +class ChunkType(str, Enum): + token = "token" + image = "image" + + +class BaseChunk[ChunkTypeT: ChunkType](BaseModel): + chunk_type: ChunkTypeT + command_id: CommandId + idx: int + model: ModelId + + +class TokenChunk(BaseChunk[ChunkType.token]): + chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True) + text: str + token_id: int + finish_reason: FinishReason | None = None + + +class ImageChunk(BaseChunk[ChunkType.image]): + chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True) + data: bytes + + +GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")] diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py new file mode 100644 index 00000000..4c9a1066 --- /dev/null +++ b/src/exo/shared/types/commands.py @@ -0,0 +1,78 @@ +from enum import Enum +from typing import Union + +from pydantic import Field + +from exo.shared.types.api import ChatCompletionTaskParams +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.models import ModelMetadata +from exo.shared.types.worker.common import InstanceId +from exo.utils.pydantic_ext import CamelCaseModel +from exo.utils.pydantic_tagged import Tagged, tagged_union + + +# TODO: We need to have a distinction between create instance and spin up instance. 
+class CommandType(str, Enum): + ChatCompletion = "ChatCompletion" + CreateInstance = "CreateInstance" + SpinUpInstance = "SpinUpInstance" + DeleteInstance = "DeleteInstance" + TaskFinished = "TaskFinished" + RequestEventLog = "RequestEventLog" + + +class BaseCommand(CamelCaseModel): + command_id: CommandId = Field(default_factory=CommandId) + + +class ChatCompletion(BaseCommand): + request_params: ChatCompletionTaskParams + + +class CreateInstance(BaseCommand): + model_meta: ModelMetadata + + +class SpinUpInstance(BaseCommand): + instance_id: InstanceId + + +class DeleteInstance(BaseCommand): + instance_id: InstanceId + + +class TaskFinished(BaseCommand): + finished_command_id: CommandId + + +class RequestEventLog(BaseCommand): + since_idx: int + + +Command = Union[ + RequestEventLog, + ChatCompletion, + CreateInstance, + SpinUpInstance, + DeleteInstance, + TaskFinished, +] + + +@tagged_union( + { + CommandType.ChatCompletion: ChatCompletion, + CommandType.CreateInstance: CreateInstance, + CommandType.SpinUpInstance: SpinUpInstance, + CommandType.DeleteInstance: DeleteInstance, + CommandType.TaskFinished: TaskFinished, + CommandType.RequestEventLog: RequestEventLog, + } +) +class TaggedCommand(Tagged[Command]): + pass + + +class ForwarderCommand(CamelCaseModel): + origin: NodeId + tagged_command: TaggedCommand diff --git a/src/exo/shared/types/common.py b/src/exo/shared/types/common.py index bc7cd127..b89ff915 100644 --- a/src/exo/shared/types/common.py +++ b/src/exo/shared/types/common.py @@ -1,5 +1,4 @@ -from ipaddress import IPv4Address, IPv6Address -from typing import Any, Self +from typing import Self from uuid import uuid4 from pydantic import BaseModel, GetCoreSchemaHandler, field_validator @@ -12,10 +11,10 @@ class ID(str): @classmethod def __get_pydantic_core_schema__( - cls, _source: type[Any], handler: GetCoreSchemaHandler + cls, _source: type, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: - # Re‑use the already‑defined schema for `str` - return handler.generate_schema(str) + # Just use a plain string schema + return core_schema.str_schema() class NodeId(ID): @@ -27,7 +26,7 @@ class CommandId(ID): class Host(BaseModel): - ip: IPv4Address | IPv6Address + ip: str port: int def __str__(self) -> str: diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py new file mode 100644 index 00000000..8d9aa32c --- /dev/null +++ b/src/exo/shared/types/events.py @@ -0,0 +1,199 @@ +from enum import Enum +from typing import Union + +from pydantic import Field + +from exo.shared.topology import Connection, NodePerformanceProfile +from exo.shared.types.chunks import CommandId, GenerationChunk +from exo.shared.types.common import ID, NodeId +from exo.shared.types.tasks import Task, TaskId, TaskStatus +from exo.shared.types.worker.common import InstanceId, WorkerStatus +from exo.shared.types.worker.instances import Instance +from exo.shared.types.worker.runners import RunnerId, RunnerStatus +from exo.utils.pydantic_ext import CamelCaseModel +from exo.utils.pydantic_tagged import Tagged, tagged_union + + +class EventId(ID): + """ + Newtype around `ID` + """ + + +class EventType(str, Enum): + """ + Here are all the unique kinds of events that can be sent over the network. + """ + + # Test Events, strictly for mocks and tests. 
+ TestEvent = "TestEvent" + + # Task Events + TaskCreated = "TaskCreated" + TaskStateUpdated = "TaskStateUpdated" + TaskFailed = "TaskFailed" + TaskDeleted = "TaskDeleted" + + # Streaming Events + ChunkGenerated = "ChunkGenerated" + + # Instance Events + InstanceCreated = "InstanceCreated" + InstanceDeleted = "InstanceDeleted" + InstanceActivated = "InstanceActivated" + InstanceDeactivated = "InstanceDeactivated" + InstanceReplacedAtomically = "InstanceReplacedAtomically" + + # Runner Status Events + RunnerStatusUpdated = "RunnerStatusUpdated" + RunnerDeleted = "RunnerDeleted" + + # Node Performance Events + WorkerStatusUpdated = "WorkerStatusUpdated" + NodePerformanceMeasured = "NodePerformanceMeasured" + + # Topology Events + TopologyNodeCreated = "TopologyNodeCreated" + TopologyEdgeCreated = "TopologyEdgeCreated" + TopologyEdgeDeleted = "TopologyEdgeDeleted" + + +class BaseEvent(CamelCaseModel): + event_id: EventId = Field(default_factory=EventId) + + +class TestEvent(BaseEvent): + pass + + +class TaskCreated(BaseEvent): + task_id: TaskId + task: Task + + +class TaskDeleted(BaseEvent): + task_id: TaskId + + +class TaskStateUpdated(BaseEvent): + task_id: TaskId + task_status: TaskStatus + + +class TaskFailed(BaseEvent): + task_id: TaskId + error_type: str + error_message: str + + +class InstanceCreated(BaseEvent): + instance: Instance + + +class InstanceActivated(BaseEvent): + instance_id: InstanceId + + +class InstanceDeactivated(BaseEvent): + instance_id: InstanceId + + +class InstanceDeleted(BaseEvent): + instance_id: InstanceId + + +class RunnerStatusUpdated(BaseEvent): + runner_id: RunnerId + runner_status: RunnerStatus + + +class RunnerDeleted(BaseEvent): + runner_id: RunnerId + + +class NodePerformanceMeasured(BaseEvent): + node_id: NodeId + node_profile: NodePerformanceProfile + + +class WorkerStatusUpdated(BaseEvent): + node_id: NodeId + node_state: WorkerStatus + + +class ChunkGenerated(BaseEvent): + command_id: CommandId + chunk: GenerationChunk + + +class TopologyNodeCreated(BaseEvent): + node_id: NodeId + + +class TopologyEdgeCreated(BaseEvent): + edge: Connection + + +class TopologyEdgeDeleted(BaseEvent): + edge: Connection + + +Event = Union[ + TestEvent, + TaskCreated, + TaskStateUpdated, + TaskFailed, + TaskDeleted, + InstanceCreated, + InstanceActivated, + InstanceDeactivated, + InstanceDeleted, + RunnerStatusUpdated, + RunnerDeleted, + NodePerformanceMeasured, + WorkerStatusUpdated, + ChunkGenerated, + TopologyNodeCreated, + TopologyEdgeCreated, + TopologyEdgeDeleted, +] + + +@tagged_union( + { + EventType.TestEvent: TestEvent, + EventType.TaskCreated: TaskCreated, + EventType.TaskStateUpdated: TaskStateUpdated, + EventType.TaskFailed: TaskFailed, + EventType.TaskDeleted: TaskDeleted, + EventType.InstanceCreated: InstanceCreated, + EventType.InstanceActivated: InstanceActivated, + EventType.InstanceDeactivated: InstanceDeactivated, + EventType.InstanceDeleted: InstanceDeleted, + EventType.RunnerStatusUpdated: RunnerStatusUpdated, + EventType.RunnerDeleted: RunnerDeleted, + EventType.NodePerformanceMeasured: NodePerformanceMeasured, + EventType.WorkerStatusUpdated: WorkerStatusUpdated, + EventType.ChunkGenerated: ChunkGenerated, + EventType.TopologyNodeCreated: TopologyNodeCreated, + EventType.TopologyEdgeCreated: TopologyEdgeCreated, + EventType.TopologyEdgeDeleted: TopologyEdgeDeleted, + } +) +class TaggedEvent(Tagged[Event]): + pass + + +class IndexedEvent(CamelCaseModel): + """An event indexed by the master, with a globally unique index""" + + idx: int = Field(ge=0) 
+ event: Event + + +class ForwarderEvent(CamelCaseModel): + """An event the forwarder will serialize and send over the network""" + + origin_idx: int = Field(ge=0) + origin: NodeId + tagged_event: TaggedEvent diff --git a/src/exo/shared/types/events/__init__.py b/src/exo/shared/types/events/__init__.py deleted file mode 100644 index 462d460c..00000000 --- a/src/exo/shared/types/events/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# ruff: noqa: F403 -# ruff: noqa: F405 - -# Note: we are implementing internal details here, so importing private stuff is fine!!! -from pydantic import TypeAdapter - -from ._events import * -from .components import EventFromEventLog - -EventParser: TypeAdapter[Event] = TypeAdapter(Event) -"""Type adaptor to parse :class:`Event`s.""" - -__all__ = ["Event", "EventParser", "EventFromEventLog"] diff --git a/src/exo/shared/types/events/_events.py b/src/exo/shared/types/events/_events.py deleted file mode 100644 index dccb9f6f..00000000 --- a/src/exo/shared/types/events/_events.py +++ /dev/null @@ -1,340 +0,0 @@ -import types -from enum import Enum -from typing import ( - TYPE_CHECKING, - Annotated, - Any, - Literal, - TypeVar, - Union, - get_args, - get_origin, - get_type_hints, -) - -from pydantic import Field - -from exo.shared.constants import get_error_reporting_message -from exo.shared.topology import Connection, ConnectionProfile, NodePerformanceProfile -from exo.shared.types.common import NodeId -from exo.shared.types.events.chunks import CommandId, GenerationChunk -from exo.shared.types.tasks import Task, TaskId, TaskStatus -from exo.shared.types.worker.common import InstanceId, NodeStatus -from exo.shared.types.worker.instances import Instance -from exo.shared.types.worker.runners import RunnerId, RunnerStatus - -if TYPE_CHECKING: - pass - -from pydantic import BaseModel - -from exo.shared.types.common import ID - - -class EventId(ID): - """ - Newtype around `ID` - """ - - -# Event base-class boilerplate (you should basically never touch these) -# Only very specialised registry or serialisation/deserialization logic might need know about these - - -class _EventType(str, Enum): - """ - Here are all the unique kinds of events that can be sent over the network. - """ - - # Heartbeat Events - Heartbeat = "Heartbeat" - - # Task Events - TaskCreated = "TaskCreated" - TaskStateUpdated = "TaskStateUpdated" - TaskFailed = "TaskFailed" - TaskDeleted = "TaskDeleted" - - # Streaming Events - ChunkGenerated = "ChunkGenerated" - - # Instance Events - InstanceCreated = "InstanceCreated" - InstanceDeleted = "InstanceDeleted" - InstanceActivated = "InstanceActivated" - InstanceDeactivated = "InstanceDeactivated" - InstanceReplacedAtomically = "InstanceReplacedAtomically" - - # Runner Status Events - RunnerStatusUpdated = "RunnerStatusUpdated" - RunnerDeleted = "RunnerDeleted" - - # Node Performance Events - NodePerformanceMeasured = "NodePerformanceMeasured" - - # Topology Events - TopologyNodeCreated = "TopologyNodeCreated" - TopologyEdgeCreated = "TopologyEdgeCreated" - TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically" - TopologyEdgeDeleted = "TopologyEdgeDeleted" - WorkerStatusUpdated = "WorkerStatusUpdated" - - # # Timer Events - # TimerCreated = "TimerCreated" - # TimerFired = "TimerFired" - - -class _BaseEvent[T: _EventType](BaseModel): - """ - This is the event base-class, to please the Pydantic gods. 
- PLEASE don't use this for anything unless you know why you are doing so, - instead just use the events union :) - """ - - event_type: T - event_id: EventId = EventId() - __no_apply__: bool = False - - def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: - """Check if the event was sent by the correct node. - - This is a placeholder implementation that always returns True. - Subclasses can override this method to implement specific validation logic. - """ - return True - - -_E = TypeVar("_E", bound=_BaseEvent[Any]) - - -def no_op_event(cls: type[_E]) -> type[_E]: - """Decorator to mark an event class as a *no-op*. - - Events marked as no-ops do not require an `event_apply` registration – the - apply layer will simply return the current state unchanged. This reduces - boilerplate and keeps console output quieter for high-frequency events - such as *Heartbeat* or streaming *ChunkGenerated* messages. - """ - - cls.__no_apply__ = True # Used by the apply layer to identify no-op events - return cls - - -@no_op_event -class Heartbeat(_BaseEvent[_EventType.Heartbeat]): - event_type: Literal[_EventType.Heartbeat] = _EventType.Heartbeat - node_id: NodeId - - -class TaskCreated(_BaseEvent[_EventType.TaskCreated]): - event_type: Literal[_EventType.TaskCreated] = _EventType.TaskCreated - task_id: TaskId - task: Task - - -class TaskDeleted(_BaseEvent[_EventType.TaskDeleted]): - event_type: Literal[_EventType.TaskDeleted] = _EventType.TaskDeleted - task_id: TaskId - - -class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]): - event_type: Literal[_EventType.TaskStateUpdated] = _EventType.TaskStateUpdated - task_id: TaskId - task_status: TaskStatus - - -class TaskFailed(_BaseEvent[_EventType.TaskFailed]): - event_type: Literal[_EventType.TaskFailed] = _EventType.TaskFailed - task_id: TaskId - error_type: str - error_message: str - - -class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]): - event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated - instance: Instance - - -class InstanceActivated(_BaseEvent[_EventType.InstanceActivated]): - event_type: Literal[_EventType.InstanceActivated] = _EventType.InstanceActivated - instance_id: InstanceId - - -class InstanceDeactivated(_BaseEvent[_EventType.InstanceDeactivated]): - event_type: Literal[_EventType.InstanceDeactivated] = _EventType.InstanceDeactivated - instance_id: InstanceId - - -class InstanceDeleted(_BaseEvent[_EventType.InstanceDeleted]): - event_type: Literal[_EventType.InstanceDeleted] = _EventType.InstanceDeleted - instance_id: InstanceId - - -class InstanceReplacedAtomically(_BaseEvent[_EventType.InstanceReplacedAtomically]): - event_type: Literal[_EventType.InstanceReplacedAtomically] = ( - _EventType.InstanceReplacedAtomically - ) - instance_to_replace: InstanceId - new_instance_id: InstanceId - - -# TODO: RunnerCreated - - -class RunnerStatusUpdated(_BaseEvent[_EventType.RunnerStatusUpdated]): - event_type: Literal[_EventType.RunnerStatusUpdated] = _EventType.RunnerStatusUpdated - runner_id: RunnerId - runner_status: RunnerStatus - - -class RunnerDeleted(_BaseEvent[_EventType.RunnerDeleted]): - event_type: Literal[_EventType.RunnerDeleted] = _EventType.RunnerDeleted - runner_id: RunnerId - - -class NodePerformanceMeasured(_BaseEvent[_EventType.NodePerformanceMeasured]): - event_type: Literal[_EventType.NodePerformanceMeasured] = ( - _EventType.NodePerformanceMeasured - ) - node_id: NodeId - node_profile: NodePerformanceProfile - - -class 
WorkerStatusUpdated(_BaseEvent[_EventType.WorkerStatusUpdated]): - event_type: Literal[_EventType.WorkerStatusUpdated] = _EventType.WorkerStatusUpdated - node_id: NodeId - node_state: NodeStatus - - -@no_op_event -class ChunkGenerated(_BaseEvent[_EventType.ChunkGenerated]): - event_type: Literal[_EventType.ChunkGenerated] = _EventType.ChunkGenerated - command_id: CommandId - chunk: GenerationChunk - - -class TopologyNodeCreated(_BaseEvent[_EventType.TopologyNodeCreated]): - event_type: Literal[_EventType.TopologyNodeCreated] = _EventType.TopologyNodeCreated - node_id: NodeId - role: Literal["MASTER", "REPLICA"] - - -class TopologyEdgeCreated(_BaseEvent[_EventType.TopologyEdgeCreated]): - event_type: Literal[_EventType.TopologyEdgeCreated] = _EventType.TopologyEdgeCreated - edge: Connection - - -class TopologyEdgeReplacedAtomically( - _BaseEvent[_EventType.TopologyEdgeReplacedAtomically] -): - """ - TODO: delete this???? - """ - - event_type: Literal[_EventType.TopologyEdgeReplacedAtomically] = ( - _EventType.TopologyEdgeReplacedAtomically - ) - edge: Connection - edge_profile: ConnectionProfile - - -class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]): - event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted - edge: Connection - - -_Event = Union[ - Heartbeat, - TaskCreated, - TaskStateUpdated, - TaskFailed, - TaskDeleted, - InstanceCreated, - InstanceActivated, - InstanceDeactivated, - InstanceDeleted, - InstanceReplacedAtomically, - RunnerStatusUpdated, - RunnerDeleted, - NodePerformanceMeasured, - WorkerStatusUpdated, - ChunkGenerated, - TopologyNodeCreated, - TopologyEdgeCreated, - TopologyEdgeReplacedAtomically, - TopologyEdgeDeleted, -] -""" -Un-annotated union of all events. Only used internally to create the registry. 
-For all other usecases, use the annotated union of events :class:`Event` :) -""" - - -def _check_event_type_consistency(): - # Grab enum values from members - member_enum_values = [m for m in _EventType] - - # grab enum values from the union => scrape the type annotation - union_enum_values: list[_EventType] = [] - union_classes = list(get_args(_Event)) - for cls in union_classes: # pyright: ignore[reportAny] - assert issubclass(cls, object), ( - f"{get_error_reporting_message()}", - f"The class {cls} is NOT a subclass of {object}.", - ) - - # ensure the first base parameter is ALWAYS _BaseEvent - base_cls = list(types.get_original_bases(cls)) - assert ( - len(base_cls) >= 1 - and issubclass(base_cls[0], object) - and issubclass(base_cls[0], _BaseEvent) - ), ( - f"{get_error_reporting_message()}", - f"The class {cls} does NOT inherit from {_BaseEvent} {get_origin(base_cls[0])}.", - ) - - # grab type hints and extract the right values from it - cls_hints = get_type_hints(cls) - assert ( - "event_type" in cls_hints and get_origin(cls_hints["event_type"]) is Literal # type: ignore - ), ( - f"{get_error_reporting_message()}", - f"The class {cls} is missing a {Literal}-annotated `event_type` field.", - ) - - # make sure the value is an instance of `_EventType` - enum_value = list(get_args(cls_hints["event_type"])) - assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), ( - f"{get_error_reporting_message()}", - f"The `event_type` of {cls} has a non-{_EventType} literal-type.", - ) - union_enum_values.append(enum_value[0]) - - # ensure there is a 1:1 bijection between the two - for m in member_enum_values: - assert m in union_enum_values, ( - f"{get_error_reporting_message()}", - f"There is no event-type registered for {m} in {_Event}.", - ) - union_enum_values.remove(m) - assert len(union_enum_values) == 0, ( - f"{get_error_reporting_message()}", - f"The following events have multiple event types defined in {_Event}: {union_enum_values}.", - ) - - -_check_event_type_consistency() - -Event = Annotated[_Event, Field(discriminator="event_type")] -"""Type of events, a discriminated union.""" - -# class TimerCreated(_BaseEvent[_EventType.TimerCreated]): -# event_type: Literal[_EventType.TimerCreated] = _EventType.TimerCreated -# timer_id: TimerId -# delay_seconds: float -# -# -# class TimerFired(_BaseEvent[_EventType.TimerFired]): -# event_type: Literal[_EventType.TimerFired] = _EventType.TimerFired -# timer_id: TimerId diff --git a/src/exo/shared/types/events/chunks.py b/src/exo/shared/types/events/chunks.py deleted file mode 100644 index 7a69ae5c..00000000 --- a/src/exo/shared/types/events/chunks.py +++ /dev/null @@ -1,71 +0,0 @@ -from enum import Enum -from typing import Annotated, Literal - -from pydantic import BaseModel, Field, TypeAdapter - -from exo.shared.openai_compat import FinishReason -from exo.shared.types.common import CommandId -from exo.shared.types.models import ModelId - - -class ChunkType(str, Enum): - token = "token" - image = "image" - - -class BaseChunk[ChunkTypeT: ChunkType](BaseModel): - chunk_type: ChunkTypeT - command_id: CommandId - idx: int - model: ModelId - - -class TokenChunk(BaseChunk[ChunkType.token]): - chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True) - text: str - token_id: int - finish_reason: FinishReason | None = None - - -class ImageChunk(BaseChunk[ChunkType.image]): - chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True) - data: bytes - - -GenerationChunk = Annotated[TokenChunk | 
ImageChunk, Field(discriminator="chunk_type")] -GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk) - -## OpenAIResponse = ( -## ChatCompletion | ChatCompletionChunk -## ) ## Currently we only support chat completions - -# my_chunk: dict[str, Any] = TokenChunk( -# task_id=TaskId('nicerid'), -# idx=0, -# text='hello', -# token_id=12, -# chunk_type=ChunkType.token, -# model='llama-3.1', -# ).model_dump() -# print(my_chunk) -# restored = GenerationChunkTypeAdapter.validate_python(my_chunk) -# print(restored) - -#### OpenAI API Interfaces ### - -""" -def send_task(task: Any) -> AsyncGenerator[GenerationChunk]: - # This is the 'command' - turns the task into an event and pushes to the event queue. - # Tokens are then read off the event queue and pushed back to the api via an AsyncGenerator. - ... - -def parse_chunk_to_openai_response(chunk: GenerationChunk) -> OpenAIResponse: - ... - -async def handle_task(task: Any) -> AsyncGenerator[OpenAIResponse]: - ## In our api call function, we will do: - generator: AsyncGenerator[GenerationChunk] = send_task(task) - - async for chunk in generator: - yield parse_chunk_to_openai_response(chunk) -""" diff --git a/src/exo/shared/types/events/commands.py b/src/exo/shared/types/events/commands.py deleted file mode 100644 index 7469e1fa..00000000 --- a/src/exo/shared/types/events/commands.py +++ /dev/null @@ -1,61 +0,0 @@ -from enum import Enum -from typing import Annotated, Callable, Literal, Sequence - -from pydantic import BaseModel, Field, TypeAdapter - -from exo.shared.types.api import ChatCompletionTaskParams -from exo.shared.types.common import CommandId -from exo.shared.types.events import Event -from exo.shared.types.models import ModelMetadata -from exo.shared.types.state import State -from exo.shared.types.worker.common import InstanceId - - -# TODO: We need to have a distinction between create instance and spin up instance. 
-class CommandType(str, Enum): - CHAT_COMPLETION = "CHAT_COMPLETION" - CREATE_INSTANCE = "CREATE_INSTANCE" - DELETE_INSTANCE = "DELETE_INSTANCE" - TASK_FINISHED = "TASK_FINISHED" - - -class _BaseCommand[T: CommandType](BaseModel): - command_id: CommandId - command_type: T - - -class ChatCompletionCommand(_BaseCommand[CommandType.CHAT_COMPLETION]): - command_type: Literal[CommandType.CHAT_COMPLETION] = CommandType.CHAT_COMPLETION - request_params: ChatCompletionTaskParams - - -class CreateInstanceCommand(_BaseCommand[CommandType.CREATE_INSTANCE]): - command_type: Literal[CommandType.CREATE_INSTANCE] = CommandType.CREATE_INSTANCE - model_meta: ModelMetadata - instance_id: InstanceId - - -class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]): - command_type: Literal[CommandType.DELETE_INSTANCE] = CommandType.DELETE_INSTANCE - instance_id: InstanceId - - -class TaskFinishedCommand(_BaseCommand[CommandType.TASK_FINISHED]): - command_type: Literal[CommandType.TASK_FINISHED] = CommandType.TASK_FINISHED - - -Command = Annotated[ - ChatCompletionCommand - | CreateInstanceCommand - | DeleteInstanceCommand - | TaskFinishedCommand, - Field(discriminator="command_type"), -] - -CommandParser: TypeAdapter[Command] = TypeAdapter(Command) - - -type Decide = Callable[ - [State, Command], - Sequence[Event], -] diff --git a/src/exo/shared/types/events/components.py b/src/exo/shared/types/events/components.py deleted file mode 100644 index d0764b85..00000000 --- a/src/exo/shared/types/events/components.py +++ /dev/null @@ -1,36 +0,0 @@ -# components.py defines the small event functions, adapters etc. -# this name could probably be improved. - -from typing import ( - TYPE_CHECKING, -) - -if TYPE_CHECKING: - pass - -from typing import Callable - -from pydantic import BaseModel, Field, model_validator - -from exo.shared.types.common import NodeId -from exo.shared.types.state import State - -from ._events import Event - - -class EventFromEventLog[T: Event](BaseModel): - event: T - origin: NodeId - idx_in_log: int = Field(gt=0) - - @model_validator(mode="after") - def check_event_was_sent_by_correct_node( - self, - ) -> "EventFromEventLog[T]": - if self.event.check_event_was_sent_by_correct_node(self.origin): - return self - raise ValueError("Invalid Event: Origin ID Does Not Match") - - -type Apply = Callable[[State, Event], State] -type ApplyFromEventLog = Callable[[State, EventFromEventLog[Event]], State] diff --git a/src/exo/shared/types/graphs/pydantic.py b/src/exo/shared/types/graphs/pydantic.py deleted file mode 100644 index ce2afabb..00000000 --- a/src/exo/shared/types/graphs/pydantic.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Any, List - -from pydantic import BaseModel - - -class PydanticGraph(BaseModel): - vertices: List[Any] - edges: List[Any] diff --git a/src/exo/shared/types/memory.py b/src/exo/shared/types/memory.py new file mode 100644 index 00000000..21cd1534 --- /dev/null +++ b/src/exo/shared/types/memory.py @@ -0,0 +1,63 @@ +from math import ceil +from typing import Self + +from exo.utils.pydantic_ext import CamelCaseModel + + +class Memory(CamelCaseModel): + in_bytes: int = 0 + + @classmethod + def from_bytes(cls, val: int) -> Self: + """Construct a new Memory object from a number of bytes""" + return cls(in_bytes=val) + + @property + def in_kb(self) -> int: + """The approximate kilobytes this memory represents, rounded up. 
Setting this property stores whole kilobytes exactly (val * 1024 bytes).""" + + @in_kb.setter + def in_kb(self, val: int): + """Set this memory's value in kilobytes.""" + self.in_bytes = val * 1024 + + @classmethod + def from_kb(cls, val: int) -> Self: + """Construct a new Memory object from a number of kilobytes""" + return cls(in_bytes=val * 1024) + + @classmethod + def from_float_kb(cls, val: float) -> Self: + """Construct a new Memory object from a number of kilobytes, rounding where appropriate""" + return cls(in_bytes=round(val * 1024)) + + @property + def in_mb(self) -> float: + """The approximate megabytes this memory represents. Setting this property rounds to the nearest byte.""" + return self.in_bytes / (1024**2) + + @in_mb.setter + def in_mb(self, val: float): + """Set the megabytes for this memory, rounded to the nearest byte.""" + self.in_bytes = round(val * (1024**2)) + + @classmethod + def from_mb(cls, val: float) -> Self: + """Construct a new Memory object from a number of megabytes""" + return cls(in_bytes=round(val * (1024**2))) + + def __add__(self, other: "Memory") -> "Memory": + return Memory.from_bytes(self.in_bytes + other.in_bytes) + + def __lt__(self, other: Self) -> bool: + return self.in_bytes < other.in_bytes + + def __le__(self, other: Self) -> bool: + return self.in_bytes <= other.in_bytes + + def __gt__(self, other: Self) -> bool: + return self.in_bytes > other.in_bytes + + def __ge__(self, other: Self) -> bool: + return self.in_bytes >= other.in_bytes diff --git a/src/exo/shared/types/models.py b/src/exo/shared/types/models.py index 3d3d0456..eaff0d79 100644 --- a/src/exo/shared/types/models.py +++ b/src/exo/shared/types/models.py @@ -1,12 +1,16 @@ -from typing import Annotated, TypeAlias +from pydantic import PositiveInt -from pydantic import BaseModel, PositiveInt - -ModelId: TypeAlias = str +from exo.shared.types.common import ID +from exo.shared.types.memory import Memory +from exo.utils.pydantic_ext import CamelCaseModel -class ModelMetadata(BaseModel): +class ModelId(ID): + pass + + +class ModelMetadata(CamelCaseModel): model_id: ModelId pretty_name: str - storage_size_kilobytes: Annotated[int, PositiveInt] - n_layers: Annotated[int, PositiveInt] + storage_size: Memory + n_layers: PositiveInt diff --git a/src/exo/shared/types/multiaddr.py b/src/exo/shared/types/multiaddr.py index 23cf55ae..769e920d 100644 --- a/src/exo/shared/types/multiaddr.py +++ b/src/exo/shared/types/multiaddr.py @@ -1,8 +1,7 @@ import re -from ipaddress import IPv4Address, IPv6Address from typing import ClassVar -from pydantic import BaseModel, computed_field, field_serializer, field_validator +from pydantic import BaseModel, computed_field, field_validator class Multiaddr(BaseModel): @@ -33,32 +32,28 @@ class Multiaddr(BaseModel): raise ValueError(f"Invalid multiaddr format: {self.address}") @property - def ipv6_address(self) -> IPv6Address: + def ipv6_address(self) -> str: match = re.match(r"^/ip6/([0-9a-fA-F:]+)", self.address) if not match: raise ValueError( f"Invalid multiaddr format: {self.address}. Expected format like /ip6/::1/tcp/4001" ) - return IPv6Address(match.group(1)) + return match.group(1) @property - def ipv4_address(self) -> IPv4Address: + def ipv4_address(self) -> str: match = re.match(r"^/ip4/(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", self.address) if not match: raise ValueError( f"Invalid multiaddr format: {self.address}. 
Expected format like /ip4/127.0.0.1/tcp/4001" ) - return IPv4Address(match.group(1)) + return match.group(1) @computed_field @property - def ip_address(self) -> IPv4Address | IPv6Address: + def ip_address(self) -> str: return self.ipv4_address if self.address_type == "ip4" else self.ipv6_address - @field_serializer("ip_address") - def serialize_ipv4_address(self, value: IPv4Address | IPv6Address) -> str: - return str(value) - @computed_field @property def port(self) -> int: diff --git a/src/exo/shared/types/profiling.py b/src/exo/shared/types/profiling.py index 304ac434..3ebb6798 100644 --- a/src/exo/shared/types/profiling.py +++ b/src/exo/shared/types/profiling.py @@ -1,14 +1,28 @@ -from pydantic import BaseModel, Field +from typing import Self + +from exo.shared.types.memory import Memory +from exo.utils.pydantic_ext import CamelCaseModel -class MemoryPerformanceProfile(BaseModel): - ram_total: int - ram_available: int - swap_total: int - swap_available: int +class MemoryPerformanceProfile(CamelCaseModel): + ram_total: Memory + ram_available: Memory + swap_total: Memory + swap_available: Memory + + @classmethod + def from_bytes( + cls, *, ram_total: int, ram_available: int, swap_total: int, swap_available: int + ) -> Self: + return cls( + ram_total=Memory.from_bytes(ram_total), + ram_available=Memory.from_bytes(ram_available), + swap_total=Memory.from_bytes(swap_total), + swap_available=Memory.from_bytes(swap_available), + ) -class SystemPerformanceProfile(BaseModel): +class SystemPerformanceProfile(CamelCaseModel): flops_fp16: float gpu_usage: float = 0.0 @@ -19,22 +33,22 @@ class SystemPerformanceProfile(BaseModel): ane_power: float = 0.0 -class NetworkInterfaceInfo(BaseModel): +class NetworkInterfaceInfo(CamelCaseModel): name: str ip_address: str type: str -class NodePerformanceProfile(BaseModel): +class NodePerformanceProfile(CamelCaseModel): model_id: str chip_id: str friendly_name: str memory: MemoryPerformanceProfile - network_interfaces: list[NetworkInterfaceInfo] = Field(default_factory=list) + network_interfaces: list[NetworkInterfaceInfo] = [] system: SystemPerformanceProfile -class ConnectionProfile(BaseModel): +class ConnectionProfile(CamelCaseModel): throughput: float latency: float jitter: float diff --git a/src/exo/shared/types/request.py b/src/exo/shared/types/request.py deleted file mode 100644 index d471be8b..00000000 --- a/src/exo/shared/types/request.py +++ /dev/null @@ -1,26 +0,0 @@ -from pydantic import BaseModel - -from exo.shared.types.api import ( - ChatCompletionTaskParams, - CreateInstanceTaskParams, - DeleteInstanceTaskParams, -) -from exo.shared.types.events import CommandId - - -class ChatCompletionCommand(BaseModel): - command_id: CommandId - command_params: ChatCompletionTaskParams - - -class CreateInstanceCommand(BaseModel): - command_id: CommandId - command_params: CreateInstanceTaskParams - - -class DeleteInstanceCommand(BaseModel): - command_id: CommandId - command_params: DeleteInstanceTaskParams - - -type Command = ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index 368400df..e599b0af 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -3,11 +3,11 @@ from typing import Any, cast from pydantic import BaseModel, ConfigDict, Field, field_validator -from exo.shared.topology import Topology +from exo.shared.topology import Topology, TopologySnapshot from exo.shared.types.common import NodeId from exo.shared.types.profiling 
import NodePerformanceProfile from exo.shared.types.tasks import Task, TaskId -from exo.shared.types.worker.common import InstanceId, NodeStatus +from exo.shared.types.worker.common import InstanceId, WorkerStatus from exo.shared.types.worker.instances import Instance from exo.shared.types.worker.runners import RunnerId, RunnerStatus @@ -32,14 +32,14 @@ class State(BaseModel): Topology: _encode_topology, }, ) - node_status: Mapping[NodeId, NodeStatus] = {} + node_status: Mapping[NodeId, WorkerStatus] = {} instances: Mapping[InstanceId, Instance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} tasks: Mapping[TaskId, Task] = {} node_profiles: Mapping[NodeId, NodePerformanceProfile] = {} topology: Topology = Topology() history: Sequence[Topology] = [] - last_event_applied_idx: int = Field(default=0, ge=0) + last_event_applied_idx: int = Field(default=-1, ge=-1) @field_validator("topology", mode="before") @classmethod @@ -53,12 +53,8 @@ class State(BaseModel): if isinstance(value, Topology): return value - # Lazy import to avoid circular dependencies. - from exo.shared.topology import Topology as _Topology - from exo.shared.topology import TopologySnapshot - if isinstance(value, Mapping): # likely a snapshot-dict coming from JSON snapshot = TopologySnapshot(**cast(dict[str, Any], value)) # type: ignore[arg-type] - return _Topology.from_snapshot(snapshot) + return Topology.from_snapshot(snapshot) raise TypeError("Invalid representation for Topology field in State") diff --git a/src/exo/shared/types/tasks.py b/src/exo/shared/types/tasks.py index 58f4b67f..200cef1c 100644 --- a/src/exo/shared/types/tasks.py +++ b/src/exo/shared/types/tasks.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal from pydantic import BaseModel, Field @@ -31,8 +31,8 @@ class ChatCompletionTask(BaseModel): task_status: TaskStatus task_params: ChatCompletionTaskParams - error_type: Optional[str] = Field(default=None) - error_message: Optional[str] = Field(default=None) + error_type: str | None = Field(default=None) + error_message: str | None = Field(default=None) Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")] diff --git a/src/exo/shared/types/topology.py b/src/exo/shared/types/topology.py index 98f1d29c..1695a98b 100644 --- a/src/exo/shared/types/topology.py +++ b/src/exo/shared/types/topology.py @@ -1,31 +1,31 @@ -from typing import Iterable, Protocol - -from pydantic import BaseModel, ConfigDict - from exo.shared.types.common import NodeId from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile +from exo.utils.pydantic_ext import CamelCaseModel -class Connection(BaseModel): +class NodeInfo(CamelCaseModel): + node_id: NodeId + node_profile: NodePerformanceProfile | None = None + + +class Connection(CamelCaseModel): local_node_id: NodeId send_back_node_id: NodeId - local_multiaddr: Multiaddr - send_back_multiaddr: Multiaddr + send_back_multiaddr: Multiaddr | None connection_profile: ConnectionProfile | None = None - # required for Connection to be used as a key - model_config = ConfigDict(frozen=True, extra="forbid", strict=True) - def __hash__(self) -> int: - return hash( - ( - self.local_node_id, - self.send_back_node_id, - self.local_multiaddr.ip_address, - self.send_back_multiaddr.ip_address, + if self.send_back_multiaddr: + return hash( + ( + self.local_node_id, + self.send_back_node_id, + self.send_back_multiaddr.address, + ) ) - ) + else: 
+ return hash((self.local_node_id, self.send_back_node_id)) def __eq__(self, other: object) -> bool: if not isinstance(other, Connection): @@ -33,48 +33,17 @@ class Connection(BaseModel): return ( self.local_node_id == other.local_node_id and self.send_back_node_id == other.send_back_node_id - and self.local_multiaddr.ip_address == other.local_multiaddr.ip_address - and self.send_back_multiaddr.ip_address - == other.send_back_multiaddr.ip_address + and self.send_back_multiaddr == other.send_back_multiaddr ) def is_thunderbolt(self) -> bool: - return str(self.local_multiaddr.ip_address).startswith("169.254") and str( - self.send_back_multiaddr.ip_address + return self.send_back_multiaddr is not None and str( + self.send_back_multiaddr.ipv4_address ).startswith("169.254") - -class Node(BaseModel): - node_id: NodeId - node_profile: NodePerformanceProfile | None = None - - -class TopologyProto(Protocol): - def add_node(self, node: Node) -> None: ... - - def add_connection( - self, - connection: Connection, - ) -> None: ... - - def list_nodes(self) -> Iterable[Node]: ... - - def list_connections(self) -> Iterable[Connection]: ... - - def update_node_profile( - self, node_id: NodeId, node_profile: NodePerformanceProfile - ) -> None: ... - - def update_connection_profile(self, connection: Connection) -> None: ... - - def remove_connection(self, connection: Connection) -> None: ... - - def remove_node(self, node_id: NodeId) -> None: ... - - def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: ... - - def get_connection_profile( - self, connection: Connection - ) -> ConnectionProfile | None: ... - - def get_cycles(self) -> list[list[Node]]: ... + def reverse(self) -> "Connection": + return Connection( + local_node_id=self.send_back_node_id, + send_back_node_id=self.local_node_id, + send_back_multiaddr=None, + ) diff --git a/src/exo/shared/types/worker/commands_runner.py b/src/exo/shared/types/worker/commands_runner.py index 512e81cc..66696482 100644 --- a/src/exo/shared/types/worker/commands_runner.py +++ b/src/exo/shared/types/worker/commands_runner.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, Generic, Literal, TypeVar +from typing import Annotated, Literal from pydantic import BaseModel, Field, TypeAdapter @@ -8,19 +8,15 @@ from exo.shared.types.common import Host from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.worker.shards import ShardMetadata + ## Messages passed TO the runner - - class MessageType(str, Enum): Setup = "setup" ChatTask = "chat_task" Exit = "exit" -MT = TypeVar(name="MT", bound=MessageType) - - -class BaseRunnerMessage(BaseModel, Generic[MT]): +class BaseRunnerMessage[MT: MessageType](BaseModel): pass @@ -47,9 +43,8 @@ RunnerMessage = Annotated[ ] RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage) + ## Responses passed FROM the runner - - class RunnerResponseType(str, Enum): InitializedResponse = "initialized_response" TokenizedResponse = "tokenized_response" @@ -59,10 +54,7 @@ class RunnerResponseType(str, Enum): ErrorResponse = "error_response" -RRT = TypeVar(name="RRT", bound=RunnerResponseType) - - -class BaseRunnerResponse(BaseModel, Generic[RRT]): +class BaseRunnerResponse[RRT: RunnerResponseType](BaseModel): pass diff --git a/src/exo/shared/types/worker/common.py b/src/exo/shared/types/worker/common.py index 37502167..55441dd9 100644 --- a/src/exo/shared/types/worker/common.py +++ b/src/exo/shared/types/worker/common.py @@ -11,7 +11,7 @@ class RunnerId(ID): 
pass -class NodeStatus(str, Enum): +class WorkerStatus(str, Enum): Idle = "Idle" Running = "Running" diff --git a/src/exo/shared/types/worker/communication.py b/src/exo/shared/types/worker/communication.py index 3afe8e69..0171acd6 100644 --- a/src/exo/shared/types/worker/communication.py +++ b/src/exo/shared/types/worker/communication.py @@ -38,7 +38,6 @@ def runner_write_error(error: Exception) -> None: logger.opt(exception=error).exception("Critical Runner error") - ## TODO: To make this cleaner, it seems like we should have only one writer. # This is fine in runner_supervisor but there's a risk in runner.py that we overlap things -# We can guarantee this by enqueueing messages and have a writing thread. \ No newline at end of file +# We can guarantee this by enqueueing messages and have a writing thread. diff --git a/src/exo/shared/types/worker/downloads.py b/src/exo/shared/types/worker/downloads.py index 54672205..aa5ee576 100644 --- a/src/exo/shared/types/worker/downloads.py +++ b/src/exo/shared/types/worker/downloads.py @@ -1,23 +1,20 @@ from enum import Enum from typing import ( Annotated, - Callable, Literal, - NewType, - Sequence, Union, ) -from pydantic import BaseModel, Field, PositiveInt +from pydantic import Field from exo.shared.types.common import NodeId -from exo.shared.types.models import ModelId -from exo.shared.types.worker.shards import ShardMetadata +from exo.shared.types.memory import Memory +from exo.utils.pydantic_ext import CamelCaseModel -class DownloadProgressData(BaseModel): - total_bytes: Annotated[int, PositiveInt] - downloaded_bytes: Annotated[int, PositiveInt] +class DownloadProgressData(CamelCaseModel): + total_bytes: Memory + downloaded_bytes: Memory class DownloadStatus(str, Enum): @@ -27,7 +24,7 @@ class DownloadStatus(str, Enum): Failed = "Failed" -class BaseDownloadProgress[DownloadStatusT: DownloadStatus](BaseModel): +class BaseDownloadProgress[DownloadStatusT: DownloadStatus](CamelCaseModel): node_id: NodeId download_status: DownloadStatusT @@ -67,18 +64,3 @@ DownloadProgress = Annotated[ ], Field(discriminator="download_status"), ] - - -BytesToDownload = NewType("BytesToDownload", int) -BytesDownloaded = NewType("BytesDownloaded", int) - -DownloadEffectHandler = Callable[ - [ModelId, DownloadStatus, BytesToDownload, BytesDownloaded], None -] - - -def download_shard( - model_id: ModelId, - shard_metadata: ShardMetadata, - effect_handlers: Sequence[DownloadEffectHandler], -) -> None: ... 
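Note: the new DownloadProgressData above tracks byte counts with the Memory model introduced earlier in this diff rather than raw integers. A minimal usage sketch of that model (illustrative only, not part of the diff):

```python
from exo.shared.types.memory import Memory

total = Memory.from_mb(512.0)            # stored internally as a byte count
done = Memory.from_bytes(256 * 1024**2)

assert done < total                      # comparison operators compare raw bytes
assert (done + done).in_bytes == total.in_bytes  # __add__ returns a new Memory
assert total.in_kb == 524288             # kilobytes are rounded up via ceil
```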
diff --git a/src/exo/shared/types/worker/runners.py b/src/exo/shared/types/worker/runners.py index 3bc70b5f..2a1e75da 100644 --- a/src/exo/shared/types/worker/runners.py +++ b/src/exo/shared/types/worker/runners.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from enum import Enum -from typing import Annotated, Generic, Literal, TypeVar +from typing import Annotated, Literal from pydantic import BaseModel, Field, TypeAdapter, model_validator @@ -20,11 +20,8 @@ class RunnerStatusType(str, Enum): Failed = "Failed" -RunnerStatusTypeT = TypeVar("RunnerStatusTypeT", bound=RunnerStatusType, covariant=True) - - -class BaseRunnerStatus(BaseModel, Generic[RunnerStatusTypeT]): - runner_status: RunnerStatusTypeT +class BaseRunnerStatus[T: RunnerStatusType](BaseModel): + runner_status: T class DownloadingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Downloading]): diff --git a/src/exo/shared/utils/pydantic_ext.py b/src/exo/shared/utils/pydantic_ext.py deleted file mode 100644 index e85591f7..00000000 --- a/src/exo/shared/utils/pydantic_ext.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel -from pydantic.alias_generators import to_camel - - -class CamelCaseModel(BaseModel): - """ - A model whose fields are aliased to camel-case from snake-case. - """ - - class Config: - alias_generator = to_camel - allow_population_by_field_name = True - - -class Tagged[Tag: str, Content]( - CamelCaseModel -): # TODO: figure out how to make pydantic work with LiteralString - """ - Utility for helping with serializing unions as adjacently tagged with Pydantic. - - By default, Pydantic uses internally tagged union ser/de BUT to play nicely with - other cross-language ser/de tools, you need adjacently tagged unions, and Pydantic - doesn't support those out of the box. - - SEE: https://serde.rs/enum-representations.html#adjacently-tagged - - Example usage: - ```python - TaggedUnion = Annotated[Union[ - Tagged[Literal["Foo"], Foo], - Tagged[Literal["Bar"], Bar] - ], Field(discriminator="t")] - - Parser: TypeAdapter[TaggedUnion] = TypeAdapter(TaggedUnion) - - def validate_python(v: any) -> Foo | Bar: - v = Parser.validate_python(v) - match v.t: - case "Foo": return v.c - case "Bar": return v.c - ``` - """ - - t: Tag - """ - The tag corresponding to the type of the object in the union. - """ - - c: Content - """ - The actual content of the object of that type. 
- """ diff --git a/src/exo/shared/utils/__init__.py b/src/exo/utils/__init__.py similarity index 82% rename from src/exo/shared/utils/__init__.py rename to src/exo/utils/__init__.py index 87131484..53679125 100644 --- a/src/exo/shared/utils/__init__.py +++ b/src/exo/utils/__init__.py @@ -1,8 +1,6 @@ -from __future__ import annotations - from typing import Any, Type -from exo.shared.utils.phantom import PhantomData +from .phantom import PhantomData def ensure_type[T](obj: Any, expected_type: Type[T]) -> T: # type: ignore diff --git a/src/exo/utils/channels.py b/src/exo/utils/channels.py new file mode 100644 index 00000000..bc203e53 --- /dev/null +++ b/src/exo/utils/channels.py @@ -0,0 +1,56 @@ +from math import inf + +from anyio import ClosedResourceError, WouldBlock +from anyio.streams.memory import ( + MemoryObjectReceiveStream as AnyioReceiver, +) +from anyio.streams.memory import ( + MemoryObjectSendStream as AnyioSender, +) +from anyio.streams.memory import ( + MemoryObjectStreamState as AnyioState, +) + + +class Sender[T](AnyioSender[T]): + def clone_receiver(self) -> "Receiver[T]": + """Constructs a Sender using a Receivers shared state - similar to calling Receiver.clone() without needing the receiver""" + if self._closed: + raise ClosedResourceError + return Receiver(_state=self._state) + + +class Receiver[T](AnyioReceiver[T]): + def clone_sender(self) -> Sender[T]: + """Constructs a Sender using a Receivers shared state - similar to calling Sender.clone() without needing the sender""" + if self._closed: + raise ClosedResourceError + return Sender(_state=self._state) + + def collect(self) -> list[T]: + """Collect all currently available items from this receiver""" + out: list[T] = [] + while True: + try: + item = self.receive_nowait() + out.append(item) + except WouldBlock: + break + return out + + async def receive_at_least(self, n: int) -> list[T]: + out: list[T] = [] + out.append(await self.receive()) + out.extend(self.collect()) + while len(out) < n: + out.append(await self.receive()) + out.extend(self.collect()) + return out + + +class channel[T]: # noqa: N801 + def __new__(cls, max_buffer_size: float = inf) -> tuple[Sender[T], Receiver[T]]: + if max_buffer_size != inf and not isinstance(max_buffer_size, int): + raise ValueError("max_buffer_size must be either an integer or math.inf") + state = AnyioState[T](max_buffer_size) + return Sender(_state=state), Receiver(_state=state) diff --git a/src/exo/utils/event_buffer.py b/src/exo/utils/event_buffer.py new file mode 100644 index 00000000..eb1b4cf0 --- /dev/null +++ b/src/exo/utils/event_buffer.py @@ -0,0 +1,67 @@ +from loguru import logger + + +class OrderedBuffer[T]: + """ + A buffer that resequences events to ensure their ordering is preserved. + Currently this buffer doesn't raise any errors if an event is lost + This buffer is NOT thread safe, and is designed to only be polled from one + source at a time. 
+ """ + + def __init__(self): + self.store: dict[int, T] = {} + self.next_idx_to_release: int = 0 + + def ingest(self, idx: int, t: T): + """Ingest a sequence into the buffer""" + logger.trace(f"Ingested event {t}") + if idx < self.next_idx_to_release: + return + if idx in self.store: + return + self.store[idx] = t + + def drain(self) -> list[T]: + """Drain all available events from the buffer""" + ret: list[T] = [] + while self.next_idx_to_release in self.store: + idx = self.next_idx_to_release + event = self.store.pop(idx) + ret.append(event) + self.next_idx_to_release += 1 + logger.trace(f"Releasing event {ret}") + return ret + + def drain_indexed(self) -> list[tuple[int, T]]: + """Drain all available events from the buffer""" + ret: list[tuple[int, T]] = [] + while self.next_idx_to_release in self.store: + idx = self.next_idx_to_release + event = self.store.pop(idx) + ret.append((idx, event)) + self.next_idx_to_release += 1 + logger.trace(f"Releasing event {ret}") + return ret + + +class MultiSourceBuffer[SourceId, T]: + """ + A buffer that resequences events to ensure their ordering is preserved. + Tracks events with multiple sources + """ + + def __init__(self): + self.stores: dict[SourceId, OrderedBuffer[T]] = {} + + def ingest(self, idx: int, t: T, source: SourceId): + if source not in self.stores: + self.stores[source] = OrderedBuffer() + buffer = self.stores[source] + buffer.ingest(idx, t) + + def drain(self) -> list[T]: + ret: list[T] = [] + for store in self.stores.values(): + ret.extend(store.drain()) + return ret diff --git a/src/exo/shared/utils/fs.py b/src/exo/utils/fs.py similarity index 96% rename from src/exo/shared/utils/fs.py rename to src/exo/utils/fs.py index a72a73ba..5419bde9 100644 --- a/src/exo/shared/utils/fs.py +++ b/src/exo/utils/fs.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import contextlib import os import pathlib diff --git a/src/exo/shared/utils/phantom.py b/src/exo/utils/phantom.py similarity index 76% rename from src/exo/shared/utils/phantom.py rename to src/exo/utils/phantom.py index 7311ea6e..4fe62afb 100644 --- a/src/exo/shared/utils/phantom.py +++ b/src/exo/utils/phantom.py @@ -1,13 +1,13 @@ from typing import Optional -class _PhantomData[T]: +class _PhantomData[*T]: """ Internal machinery of the phantom data - it stores nothing. """ -type PhantomData[T] = Optional[_PhantomData[T]] +type PhantomData[*T] = Optional[_PhantomData[*T]] """ Allows you to use generics in functions without storing anything of that generic type. Just use `None` and you'll be fine diff --git a/src/exo/utils/pydantic_ext.py b/src/exo/utils/pydantic_ext.py new file mode 100644 index 00000000..1bbedea2 --- /dev/null +++ b/src/exo/utils/pydantic_ext.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + + +class CamelCaseModel(BaseModel): + """ + A model whose fields are aliased to camel-case from snake-case. 
+ """ + + model_config = ConfigDict( + alias_generator=to_camel, + validate_by_name=True, + extra="forbid", + # I want to reenable this ASAP, but it's causing an issue with TaskStatus + # strict=True, + ) diff --git a/src/exo/utils/pydantic_tagged.py b/src/exo/utils/pydantic_tagged.py new file mode 100644 index 00000000..3840e7dd --- /dev/null +++ b/src/exo/utils/pydantic_tagged.py @@ -0,0 +1,229 @@ +# pyright: reportAny=false, reportPrivateUsage=false, reportUnusedParameter=false, reportUnknownMemberType=false + +from collections.abc import Callable +from types import get_original_bases +from typing import ( + Any, + ClassVar, + Self, + Union, + cast, + get_args, + get_origin, +) + +import pydantic +from bidict import bidict +from pydantic import ( + BaseModel, + Field, + TypeAdapter, + model_serializer, + model_validator, +) +from pydantic_core import ( + PydanticCustomError, +) + + +def tagged_union[T: Tagged[Any]]( + type_map: dict[str, type], +) -> Callable[[type[T]], type[T]]: + def _decorator(cls: type[T]): + # validate and process the types + tagged_union_cls = _ensure_single_tagged_union_base(cls) + adapter_dict = _ensure_tagged_union_generic_is_union(tagged_union_cls) + type_bidict = _ensure_bijection_between_union_members_and_type_map( + set(adapter_dict.keys()), type_map + ) + + # inject the adapter and type class variables + cast(type[_TaggedImpl[Any]], cls)._type_bidict = type_bidict + cast(type[_TaggedImpl[Any]], cls)._adapter_dict = adapter_dict + + return cls + + return _decorator + + +class Tagged[C](BaseModel): + """ + Utility for helping with serializing unions as adjacently tagged with Pydantic. + + By default, Pydantic uses internally tagged union ser/de BUT to play nicely with + other cross-language ser/de tools, you need adjacently tagged unions, and Pydantic + doesn't support those out of the box. + SEE: https://serde.rs/enum-representations.html#adjacently-tagged + + This type is a Pydantic model in its own right and can be used on fields of other + Pydantic models. It must be used in combination with `tagged_union` decorator to work. + + Example usage: + ```python + FoobarUnion = Union[Foo, Bar, Baz] + + @tagged_union({ + "Foo": Foo, + "Bar": Bar, + "Baz": Baz, + }) + class TaggedFoobarUnion(Tagged[FoobarUnion]): ... + ``` + """ + + t: str = Field(frozen=True) + """ + The tag corresponding to the type of the object in the union. + """ + + c: C = Field(frozen=True) + """ + The actual content of the object of that type. 
+ """ + + @classmethod + def from_(cls, c: C) -> Self: + t = cast(type[_TaggedImpl[C]], cls)._type_bidict.inv[type(c)] + return cls(t=t, c=c) + + @model_serializer + def _model_dump(self) -> dict[str, Any]: + cls = type(cast(_TaggedImpl[C], self)) + adapter = cls._adapter_dict[cls._type_bidict[self.t]] + return { + "t": self.t, + "c": adapter.dump_python(self.c), + } + + @model_validator(mode="before") + @classmethod + def _model_validate_before(cls, data: Any) -> Any: + cls = cast(type[_TaggedImpl[C]], cls) + + # check object shape & check "t" type is `str` + if not isinstance(data, dict): + raise PydanticCustomError( + "dict_type", "Wrong object type: expected a dictionary type" + ) + if "t" not in data or "c" not in data or len(data) != 2: # pyright: ignore[reportUnknownArgumentType] + raise ValueError( + "Wrong object shape: expected exactly {t: , c: }" + ) + if not isinstance(data["t"], str): + raise PydanticCustomError( + "string_type", 'Wrong field type: expected "t" to be `str`' + ) + + # grab tag & content keys + look up the type based on the tag + t = data["t"] + c = cast(Any, data["c"]) + ccls = cls._type_bidict.get(t) + if ccls is None: + raise PydanticCustomError( + "union_tag_not_found", + 'Wrong "t"-value: could not find tag within this discriminated union', + ) + cadapter = cls._adapter_dict[ccls] + + return { + "t": t, + "c": cadapter.validate_python(c), + } + + @model_validator(mode="after") + def _model_validate_after(self) -> Self: + cls = type(cast(_TaggedImpl[C], self)) + ccls = type(self.c) + + # sanity check for consistency + t = cls._type_bidict.inv.get(ccls) + if t is None: + raise ValueError( + 'Wrong "c"-value: could not find a tag corresponding to the type of this value' + ) + if t != self.t: + raise ValueError( + 'Wrong "t"-value: the provided tag for this content\'s type mismatches the configured tag' + ) + + return self + + +class _TaggedImpl[C](Tagged[C]): + _type_bidict: ClassVar[bidict[str, type]] + _adapter_dict: ClassVar[dict[type, TypeAdapter[Any]]] + + +def _ensure_single_tagged_union_base(cls: type[Any]) -> type[Any]: + bases = get_original_bases(cls) + + # count up all the bases (generic removed) and store last found one + cnt = 0 + last = None + for b in bases: + if pydantic._internal._generics.get_origin(b) == Tagged: # pyright: ignore[reportAttributeAccessIssue] + last = cast(type[Tagged[Any]], b) + cnt += 1 + + # sanity-check the bases + if last is None: + raise TypeError(f"Expected {Tagged!r} to be a base-class of {cls!r}") + if cnt > 1: + raise TypeError( + f"Expected only one {Tagged!r} base-class of {cls!r}, but got {cnt}" + ) + + return last + + +def _ensure_tagged_union_generic_is_union( + cls: type[Any], +) -> dict[type, TypeAdapter[Any]]: + # extract type of the generic argument + base_generics = cast(Any, pydantic._internal._generics.get_args(cls)) # pyright: ignore[reportAttributeAccessIssue] + assert len(base_generics) == 1 + union_cls = base_generics[0] + + # ensure the generic is a union => extract the members + union_origin = get_origin(union_cls) + if union_origin != Union: + raise TypeError( + f"Expected {Tagged!r} base-class to have its generic be a {Union!r}, but got {union_cls!r}" + ) + union_members = get_args(union_cls) + + # typecheck each of the members, creating a type<->adapter mapping + adapter_dict: dict[type, TypeAdapter[Any]] = {} + for m in union_members: + if not isinstance(m, type): + raise TypeError(f"Expected union member {m!r} to be a type") + adapter_dict[m] = TypeAdapter(m) + + return adapter_dict + + +def 
_ensure_bijection_between_union_members_and_type_map( + members: set[type], type_map: dict[str, type] +) -> bidict[str, type]: + mapped_members = set(type_map.values()) + + illegal_members = mapped_members - members + for m in illegal_members: + raise TypeError( + f"Expected type-map member {m!r} to be member of the union, but is not" + ) + missing_members = members - mapped_members + for m in missing_members: + raise TypeError( + f"Expected type-map to include a tag for member {m!r}, but is missing" + ) + assert mapped_members == members + + tag_sets = {m: {t for t in type_map if type_map[t] == m} for m in mapped_members} + for m, ts in tag_sets.items(): + if len(ts) > 1: + raise TypeError( + f"Expected a single tag per member of the union, but found {ts} for member {m!r}" + ) + + return bidict(type_map) diff --git a/src/exo/shared/utils/reactive.py b/src/exo/utils/reactive.py similarity index 100% rename from src/exo/shared/utils/reactive.py rename to src/exo/utils/reactive.py diff --git a/src/exo/utils/tests/test_tagged.py b/src/exo/utils/tests/test_tagged.py new file mode 100644 index 00000000..b138dcac --- /dev/null +++ b/src/exo/utils/tests/test_tagged.py @@ -0,0 +1,182 @@ +from typing import Union + +import pytest +from pydantic import BaseModel, TypeAdapter, ValidationError + +from exo.utils.pydantic_tagged import Tagged, tagged_union # ← CHANGE ME + + +def test_plain_union_prefers_first_member_when_shapes_are_identical(): + class Foo1(BaseModel): + x: int + + class Foo2(BaseModel): + x: int + + # Base Pydantic behavior: ambiguous dict goes to the first union member + ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2) + out = ta.validate_python({"x": 1}) + assert isinstance(out, Foo1), ( + "Base Pydantic should pick the first union member for identical shapes" + ) + + +def test_tagged_union_serializes_and_deserializes_two_identical_shapes_correctly(): + class Foo1(BaseModel): + x: int + + class Foo2(BaseModel): + x: int + + foos = Union[Foo1, Foo2] + + @tagged_union({"Foo1": Foo1, "Foo2": Foo2}) + class TaggedFoos(Tagged[foos]): + pass + + # ---- serialize (via custom model_serializer) ---- + t1 = TaggedFoos.from_(Foo1(x=1)) + assert t1.model_dump() == {"t": "Foo1", "c": {"x": 1}} + + t2 = TaggedFoos.from_(Foo2(x=2)) + assert t2.model_dump() == {"t": "Foo2", "c": {"x": 2}} + + # ---- deserialize (TypeAdapter -> model_validator(before)) ---- + ta = TypeAdapter(TaggedFoos) + + out1 = ta.validate_python({"t": "Foo1", "c": {"x": 10}}) + assert isinstance(out1.c, Foo1) and out1.c.x == 10 + + out2 = ta.validate_python({"t": "Foo2", "c": {"x": 20}}) + assert isinstance(out2.c, Foo2) and out2.c.x == 20 + + +def test_tagged_union_rejects_unknown_tag(): + class Foo1(BaseModel): + x: int + + class Foo2(BaseModel): + x: int + + foos = Union[Foo1, Foo2] + + @tagged_union({"Foo1": Foo1, "Foo2": Foo2}) + class TaggedFoos(Tagged[foos]): + pass + + ta = TypeAdapter(TaggedFoos) + with pytest.raises(ValidationError): + ta.validate_python({"t": "NotARealTag", "c": {"x": 0}}) + + +def test_multiple_tagged_classes_do_not_override_each_others_mappings(): + """ + Creating a *new* Tagged[T] class must not mutate the previously defined one. + This checks both the tag mapping and the per-class adapter dicts. 
+ """ + + class Foo1(BaseModel): + x: int + + class Foo2(BaseModel): + x: int + + foos = Union[Foo1, Foo2] + + @tagged_union({"One": Foo1, "Two": Foo2}) + class TaggedEN(Tagged[foos]): + pass + + # Sanity: initial mapping/behavior + obj_en_1 = TaggedEN.from_(Foo1(x=5)) + assert obj_en_1.t == "One" + obj_en_2 = TaggedEN.from_(Foo2(x=6)) + assert obj_en_2.t == "Two" + + # Define a second, different mapping + @tagged_union({"Uno": Foo1, "Dos": Foo2}) + class TaggedES(Tagged[foos]): + pass + + # The two classes should have *independent* mappings + # (not the same object, and not equal content) + assert TaggedEN._type_bidict is not TaggedES._type_bidict # pyright: ignore + assert TaggedEN._type_bidict != TaggedES._type_bidict # pyright: ignore + + # Their adapters dicts should also be distinct objects + assert TaggedEN._adapter_dict is not TaggedES._adapter_dict # pyright: ignore + # And both should cover the same set of member types + assert set(TaggedEN._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore + assert set(TaggedES._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore + + # Re-check that EN behavior has NOT changed after ES was created + obj_en_1_again = TaggedEN.from_(Foo1(x=7)) + obj_en_2_again = TaggedEN.from_(Foo2(x=8)) + assert obj_en_1_again.t == "One" + assert obj_en_2_again.t == "Two" + + # ES behavior is per its *own* mapping + obj_es_1 = TaggedES.from_(Foo1(x=9)) + obj_es_2 = TaggedES.from_(Foo2(x=10)) + assert obj_es_1.t == "Uno" + assert obj_es_2.t == "Dos" + + # And deserialization respects each class's mapping independently + ta_en = TypeAdapter(TaggedEN) + ta_es = TypeAdapter(TaggedES) + + out_en = ta_en.validate_python({"t": "Two", "c": {"x": 123}}) + assert isinstance(out_en.c, Foo2) and out_en.c.x == 123 + + out_es = ta_es.validate_python({"t": "Dos", "c": {"x": 456}}) + assert isinstance(out_es.c, Foo2) and out_es.c.x == 456 + + +def test_two_tagged_classes_with_different_shapes_are_independent_and_not_cross_deserializable(): + class A1(BaseModel): + x: int + + class A2(BaseModel): + name: str + + union_a = Union[A1, A2] + + @tagged_union({"One": A1, "Two": A2}) + class TaggedA(Tagged[union_a]): + pass + + class B1(BaseModel): + name: str + + class B2(BaseModel): + active: bool + + union_b = Union[B1, B2] + + # Note: using the SAME tag strings intentionally to ensure mappings are per-class + @tagged_union({"One": B1, "Two": B2}) + class TaggedB(Tagged[union_b]): + pass + + # --- Per-class state must be independent --- + assert TaggedA._type_bidict is not TaggedB._type_bidict # pyright: ignore + assert TaggedA._adapter_dict is not TaggedB._adapter_dict # pyright: ignore + assert set(TaggedA._adapter_dict.keys()) == {A1, A2} # pyright: ignore + assert set(TaggedB._adapter_dict.keys()) == {B1, B2} # pyright: ignore + + # --- Round-trip for each class with overlapping tag strings --- + a_payload = TaggedA.from_(A1(x=123)).model_dump() + b_payload = TaggedB.from_(B1(name="neo")).model_dump() + + assert a_payload == {"t": "One", "c": {"x": 123}} + assert b_payload == {"t": "One", "c": {"name": "neo"}} + + # --- Cross-deserialization must fail despite overlapping "t" values --- + ta_a = TypeAdapter(TaggedA) + ta_b = TypeAdapter(TaggedB) + + with pytest.raises(ValidationError): + ta_a.validate_python(b_payload) # TaggedA expects {"x": ...} for tag "One" + + with pytest.raises(ValidationError): + ta_b.validate_python(a_payload) # TaggedB expects {"name": ...} for tag "One" diff --git a/src/exo/worker/common.py b/src/exo/worker/common.py index 143061a7..535fd8b3 
100644 --- a/src/exo/worker/common.py +++ b/src/exo/worker/common.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -24,7 +23,7 @@ class AssignedRunner(BaseModel): status: RunnerStatus failures: list[tuple[float, Exception]] = [] - runner: Optional[RunnerSupervisor] # set if the runner is 'up' + runner: RunnerSupervisor | None # set if the runner is 'up' model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/exo/worker/download/download_utils.py b/src/exo/worker/download/download_utils.py index e2b4e8a2..b03e59eb 100644 --- a/src/exo/worker/download/download_utils.py +++ b/src/exo/worker/download/download_utils.py @@ -6,13 +6,13 @@ import time import traceback from datetime import timedelta from pathlib import Path -from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from urllib.parse import urljoin import aiofiles import aiofiles.os as aios import aiohttp -from pydantic import BaseModel, DirectoryPath, Field, TypeAdapter +from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter from exo.shared.constants import EXO_HOME from exo.shared.types.worker.shards import ShardMetadata @@ -25,7 +25,7 @@ from exo.worker.download.huggingface_utils import ( class ModelSafetensorsIndexMetadata(BaseModel): - total_size: Annotated[int, Field(ge=0)] + total_size: PositiveInt class ModelSafetensorsIndex(BaseModel): diff --git a/src/exo/worker/download/shard_downloader.py b/src/exo/worker/download/shard_downloader.py index ddb78915..30615222 100644 --- a/src/exo/worker/download/shard_downloader.py +++ b/src/exo/worker/download/shard_downloader.py @@ -3,7 +3,8 @@ from datetime import timedelta from pathlib import Path from typing import AsyncIterator, Callable -from exo.shared.types.models import ModelMetadata +from exo.shared.types.memory import Memory +from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.worker.shards import ( PartitionStrategy, PipelineShardMetadata, @@ -51,9 +52,9 @@ class ShardDownloader(ABC): repo_revision="noop", shard=PipelineShardMetadata( model_meta=ModelMetadata( - model_id="noop", + model_id=ModelId("noop"), pretty_name="noope", - storage_size_kilobytes=0, + storage_size=Memory.from_bytes(0), n_layers=1, ), partition_strategy=PartitionStrategy.pipeline, @@ -101,9 +102,9 @@ class NoopShardDownloader(ShardDownloader): repo_revision="noop", shard=PipelineShardMetadata( model_meta=ModelMetadata( - model_id="noop", + model_id=ModelId("noop"), pretty_name="noope", - storage_size_kilobytes=0, + storage_size=Memory.from_bytes(0), n_layers=1, ), partition_strategy=PartitionStrategy.pipeline, diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index edb58f2c..24c60323 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -1,109 +1,643 @@ import asyncio -from pathlib import Path +import time +from asyncio import Queue +from functools import partial +from random import random +from typing import AsyncGenerator, Optional +import anyio +from anyio import CancelScope, create_task_group +from anyio.abc import TaskGroup from loguru import logger +from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.shared.apply import apply -from exo.shared.constants import EXO_WORKER_LOG -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.keypair import Keypair, 
get_node_id_keypair -from exo.shared.logging import logger_cleanup, logger_setup +from exo.shared.types.commands import ForwarderCommand, RequestEventLog, TaggedCommand from exo.shared.types.common import NodeId from exo.shared.types.events import ( + ChunkGenerated, + Event, + EventId, + ForwarderEvent, + IndexedEvent, + InstanceDeleted, NodePerformanceMeasured, + RunnerDeleted, + RunnerStatusUpdated, + TaggedEvent, + TaskFailed, + TaskStateUpdated, + TopologyEdgeCreated, + TopologyEdgeDeleted, ) +from exo.shared.types.memory import Memory +from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import NodePerformanceProfile -from exo.shared.types.worker.ops import ( - ExecuteTaskOp, - RunnerOp, +from exo.shared.types.state import State +from exo.shared.types.tasks import TaskId, TaskStatus +from exo.shared.types.topology import Connection +from exo.shared.types.worker.common import RunnerId +from exo.shared.types.worker.downloads import ( + DownloadCompleted, + DownloadOngoing, + DownloadPending, + DownloadProgressData, ) -from exo.worker.download.impl_shard_downloader import exo_shard_downloader +from exo.shared.types.worker.ops import ( + AssignRunnerOp, + ExecuteTaskOp, + RunnerDownOp, + RunnerFailedOp, + RunnerOp, + RunnerOpType, + RunnerUpOp, + UnassignRunnerOp, +) +from exo.shared.types.worker.runners import ( + DownloadingRunnerStatus, + FailedRunnerStatus, + InactiveRunnerStatus, + LoadedRunnerStatus, + RunningRunnerStatus, + StartingRunnerStatus, +) +from exo.shared.types.worker.shards import ShardMetadata +from exo.utils.channels import Receiver, Sender +from exo.utils.event_buffer import OrderedBuffer +from exo.worker.common import AssignedRunner +from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader from exo.worker.plan import plan -from exo.worker.utils.profile import start_polling_node_metrics -from exo.worker.worker import Worker +from exo.worker.runner.runner_supervisor import RunnerSupervisor +from exo.worker.utils import start_polling_node_metrics -async def run(worker: Worker): - assert worker.global_events is not None +class Worker: + def __init__( + self, + node_id: NodeId, + shard_downloader: ShardDownloader, + *, + initial_connection_messages: list[ConnectionMessage], + connection_message_receiver: Receiver[ConnectionMessage], + # Having written this pattern 3 times in the codebase: + # Should this be inherited??? Is this a real inheritance + # W???? + # Limitation: This SHOULD be a MasterForwarderEvent, but inheritance says no :| + global_event_receiver: Receiver[ForwarderEvent], + # Limitation: This SHOULD be a WorkerForwarderEvent, but inheritance says no :| + local_event_sender: Sender[ForwarderEvent], + # This is for requesting updates. It doesn't need to be a general command sender right now, + # but I think it's the correct way to be thinking about commands + command_sender: Sender[ForwarderCommand], + ): + self.node_id: NodeId = node_id + self.shard_downloader: ShardDownloader = shard_downloader + self.global_event_receiver = global_event_receiver + self.local_event_sender = local_event_sender + self.local_event_index = 0 + self.command_sender = command_sender + self.connection_message_receiver = connection_message_receiver + self.event_buffer = OrderedBuffer[Event]() + self._initial_connection_messages = initial_connection_messages + self.out_for_delivery: dict[EventId, ForwarderEvent] = {} - while True: - # 1. 
get latest events - events = await worker.global_events.get_events_since( - worker.state.last_event_applied_idx - ) + self.state: State = State() + self.assigned_runners: dict[RunnerId, AssignedRunner] = {} + self._tg: TaskGroup | None = None + self._nack_cancel_scope: CancelScope | None = None - # 2. for each event, apply it to the state and run sagas - for event_from_log in events: - worker.state = apply(worker.state, event_from_log) + async def run(self): + logger.info("Starting Worker") + # TODO: CLEANUP HEADER + async def resource_monitor_callback( + node_performance_profile: NodePerformanceProfile, + ) -> None: + await self.event_publisher( + NodePerformanceMeasured( + node_id=self.node_id, node_profile=node_performance_profile + ), + ) + + # END CLEANUP + + async with create_task_group() as tg: + self._tg = tg + tg.start_soon(start_polling_node_metrics, resource_monitor_callback) + tg.start_soon(self._connection_message_event_writer) + tg.start_soon(self._resend_out_for_delivery) + tg.start_soon(self._event_applier) + # TODO: This is a little gross, but not too bad + for msg in self._initial_connection_messages: + await self.event_publisher( + self._convert_connection_message_to_event(msg) + ) + self._initial_connection_messages = [] + + # Actual shutdown code - waits for all tasks to complete before executing. + self.local_event_sender.close() + self.command_sender.close() + for runner in self.assigned_runners.values(): + if runner.runner: + await runner.runner.astop() + + async def _event_applier(self): + with self.global_event_receiver as events: + async for event in events: + self.event_buffer.ingest(event.origin_idx, event.tagged_event.c) + event_id = event.tagged_event.c.event_id + if event_id in self.out_for_delivery: + del self.out_for_delivery[event_id] + + # 2. for each event, apply it to the state + indexed_events = self.event_buffer.drain_indexed() + if not indexed_events: + if ( + self._nack_cancel_scope is None + or self._nack_cancel_scope.cancel_called + ): + assert self._tg + self._tg.start_soon(self._nack_request) + elif self._nack_cancel_scope: + self._nack_cancel_scope.cancel() + + flag = False + for idx, event in indexed_events: + self.state = apply(self.state, IndexedEvent(idx=idx, event=event)) + if event_relevant_to_worker(event, self): + flag = True + + # 3. If we've found a "relevant" event, run a plan -> op -> execute cycle. + if flag: + await self.plan_step() + + async def plan_step(self): # 3. based on the updated state, we plan & execute an operation. 
op: RunnerOp | None = plan( - worker.assigned_runners, - worker.node_id, - worker.state.instances, - worker.state.runners, - worker.state.tasks, + self.assigned_runners, + self.node_id, + self.state.instances, + self.state.runners, + self.state.tasks, ) # run the op, synchronously blocking for now if op is not None: - logger.info(f"Executing op {str(op)[:500]}") - logger.bind(user_facing=True).debug(f"Worker executing op: {str(op)[:500]}") + logger.info(f"Executing op {str(op)[:100]}") + logger.debug(f"Worker executing op: {str(op)[:100]}") try: - async for event in worker.execute_op(op): - await worker.event_publisher(event) + async for event in self.execute_op(op): + await self.event_publisher(event) except Exception as e: if isinstance(op, ExecuteTaskOp): - generator = worker.fail_task( + generator = self.fail_task( e, runner_id=op.runner_id, task_id=op.task.task_id ) else: - generator = worker.fail_runner(e, runner_id=op.runner_id) + generator = self.fail_runner(e, runner_id=op.runner_id) async for event in generator: - await worker.event_publisher(event) + await self.event_publisher(event) - await asyncio.sleep(0.01) + def shutdown(self): + if self._tg: + self._tg.cancel_scope.cancel() - -async def async_main(): - node_id_keypair: Keypair = get_node_id_keypair() - node_id = NodeId(node_id_keypair.to_peer_id().to_base58()) - - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = exo_shard_downloader() - - # TODO: add profiling etc to resource monitor - async def resource_monitor_callback( - node_performance_profile: NodePerformanceProfile, - ) -> None: - await event_log_manager.worker_events.append_events( - [ - NodePerformanceMeasured( - node_id=node_id, node_profile=node_performance_profile + async def _connection_message_event_writer(self): + with self.connection_message_receiver as connection_messages: + async for msg in connection_messages: + await self.event_publisher( + self._convert_connection_message_to_event(msg) + ) + + def _convert_connection_message_to_event(self, msg: ConnectionMessage): + match msg.connection_type: + case ConnectionMessageType.Connected: + return TopologyEdgeCreated( + edge=Connection( + local_node_id=self.node_id, + send_back_node_id=msg.node_id, + send_back_multiaddr=Multiaddr( + address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}" + ) + ) + ) + + case ConnectionMessageType.Disconnected: + return TopologyEdgeDeleted( + edge=Connection( + local_node_id=self.node_id, + send_back_node_id=msg.node_id, + send_back_multiaddr=Multiaddr( + address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}" + ) + ) + ) + + async def _nack_request(self) -> None: + # This function is started whenever we receive an event that is out of sequence. + # It is cancelled as soon as we receive an event that is in sequence. + # Thus, if we don't make any progress within 1 + random() seconds, we request a copy of the event log + # This can be MASSIVELY tightened - just requesting a single event should be sufficient.
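+        # Illustrative timeline of the intended behaviour (an assumption drawn from the code below):
+        #   t=0      event idx 7 arrives but idx 5..6 are missing -> this task is started
+        #   t<1      idx 5 arrives and the buffer drains in sequence -> the scope below is cancelled, no NACK
+        #   t>=1+r   still no progress -> a RequestEventLog(since_idx=0) command is sent as a NACK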
+ with CancelScope() as scope: + self._nack_cancel_scope = scope + try: + await anyio.sleep(1 + random()) + await self.command_sender.send( + ForwarderCommand( + origin=self.node_id, + tagged_command=TaggedCommand.from_( + RequestEventLog(since_idx=0) + ), + ) + ) + finally: + if self._nack_cancel_scope is scope: + self._nack_cancel_scope = None + + async def _resend_out_for_delivery(self) -> None: + # This can also be massively tightened, we should check events are at least a certain age before resending. + # Exponential backoff would also certainly help here. + while True: + await anyio.sleep(1 + random()) + for event in self.out_for_delivery.copy().values(): + await self.local_event_sender.send(event) + + ## Op Executors + + def _create_assigned_runner(self, op: AssignRunnerOp) -> AssignedRunner: + """Creates and stores a new AssignedRunner with initial downloading status.""" + assigned_runner = AssignedRunner( + runner_id=op.runner_id, + instance_id=op.instance_id, + shard_metadata=op.shard_metadata, + hosts=op.hosts, + status=DownloadingRunnerStatus( + download_progress=DownloadPending(node_id=self.node_id) + ), + runner=None, + ) + self.assigned_runners[op.runner_id] = assigned_runner + return assigned_runner + + async def _update_runner_status_to_completed_then_inactive( + self, assigned_runner: AssignedRunner + ) -> AsyncGenerator[Event, None]: + """Updates runner status from downloading to completed, then to inactive.""" + assigned_runner.status = DownloadingRunnerStatus( + download_progress=DownloadCompleted(node_id=self.node_id) + ) + yield assigned_runner.status_update_event() + + assigned_runner.status = InactiveRunnerStatus() + yield assigned_runner.status_update_event() + + async def _handle_already_downloaded_shard( + self, assigned_runner: AssignedRunner + ) -> AsyncGenerator[Event, None]: + """Handles the case where the shard is already downloaded.""" + async for event in self._update_runner_status_to_completed_then_inactive( + assigned_runner + ): + yield event + + async def _handle_shard_download_process( + self, + assigned_runner: AssignedRunner, + op: AssignRunnerOp, + initial_progress: RepoDownloadProgress, + ) -> AsyncGenerator[Event, None]: + """Manages the shard download process with progress tracking.""" + # Set initial ongoing status + assigned_runner.status = DownloadingRunnerStatus( + download_progress=DownloadOngoing( + node_id=self.node_id, + download_progress=DownloadProgressData( + total_bytes=Memory.from_bytes(initial_progress.total_bytes), + downloaded_bytes=Memory.from_bytes( + initial_progress.downloaded_bytes + ), + ), + ) + ) + yield assigned_runner.status_update_event() + + # Set up download progress tracking + download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue() + + def download_progress_callback( + shard: ShardMetadata, progress: RepoDownloadProgress + ) -> None: + download_progress_queue.put_nowait(progress) + + self.shard_downloader.on_progress(download_progress_callback) + download_task = asyncio.create_task( + self.shard_downloader.ensure_shard(op.shard_metadata) ) - asyncio.create_task(start_polling_node_metrics(callback=resource_monitor_callback)) + try: + async for event in self._monitor_download_progress( + assigned_runner, download_progress_queue + ): + yield event + finally: + if not download_task.done(): + download_task.cancel() - worker = Worker( - node_id, - shard_downloader, - event_log_manager.worker_events, - event_log_manager.global_events, - ) + async def _monitor_download_progress( + self, + 
assigned_runner: AssignedRunner, + download_progress_queue: asyncio.Queue[RepoDownloadProgress], + ) -> AsyncGenerator[Event, None]: + """Monitors download progress and yields status updates.""" + last_progress_time = 0.0 + throttle_interval_secs = 1.0 - await run(worker) - logger_cleanup() + while True: + progress: RepoDownloadProgress = await asyncio.wait_for( + download_progress_queue.get(), timeout=15 + ) + + if progress.status == "complete": + async for ( + event + ) in self._update_runner_status_to_completed_then_inactive( + assigned_runner + ): + yield event + break + elif progress.status == "in_progress": + if time.monotonic() - last_progress_time > throttle_interval_secs: + assigned_runner.status = DownloadingRunnerStatus( + download_progress=DownloadOngoing( + node_id=self.node_id, + download_progress=DownloadProgressData( + total_bytes=Memory.from_bytes(progress.total_bytes), + downloaded_bytes=Memory.from_bytes( + progress.downloaded_bytes + ), + ), + ) + ) + yield assigned_runner.status_update_event() + last_progress_time = time.monotonic() + + async def _execute_assign_op( + self, op: AssignRunnerOp + ) -> AsyncGenerator[Event, None]: + """ + A runner has been assigned. We need to also ensure that it's downloaded. + This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state. + """ + assigned_runner = self._create_assigned_runner(op) + initial_progress = ( + await self.shard_downloader.get_shard_download_status_for_shard( + op.shard_metadata + ) + ) + + if initial_progress.status == "complete": + async for event in self._handle_already_downloaded_shard(assigned_runner): + yield event + else: + async for event in self._handle_shard_download_process( + assigned_runner, op, initial_progress + ): + yield event + + async def _execute_unassign_op( + self, op: UnassignRunnerOp + ) -> AsyncGenerator[Event, None]: + if op.runner_id not in self.assigned_runners: + return + + # We can try to do a graceful shutdown of the runner. 
+ runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner + if runner is not None: + await runner.astop() + + # This is all we really need: + del self.assigned_runners[op.runner_id] + yield RunnerDeleted(runner_id=op.runner_id) + + async def _execute_runner_up_op( + self, op: RunnerUpOp, initialize_timeout: Optional[float] = None + ) -> AsyncGenerator[Event, None]: + assigned_runner = self.assigned_runners[op.runner_id] + + # Emit "Starting" status right away so UI can show loading state + assigned_runner.status = StartingRunnerStatus() + yield assigned_runner.status_update_event() + + assigned_runner.runner = await RunnerSupervisor.create( + model_shard_meta=assigned_runner.shard_metadata, + hosts=assigned_runner.hosts, + initialize_timeout=initialize_timeout, + ) + + if assigned_runner.runner.runner_process.is_alive(): + assigned_runner.status = LoadedRunnerStatus() + else: + runner = assigned_runner.runner + logger.warning( + f"Runner process is not alive; exit code {runner.runner_process.exitcode}" + ) + + assigned_runner.status = FailedRunnerStatus() + yield self.assigned_runners[op.runner_id].status_update_event() + + async def _execute_runner_down_op( + self, op: RunnerDownOp + ) -> AsyncGenerator[Event, None]: + assigned_runner = self.assigned_runners[op.runner_id] + + if isinstance(assigned_runner.runner, RunnerSupervisor): + await assigned_runner.runner.astop() + + assigned_runner.runner = None + + assigned_runner.status = InactiveRunnerStatus() + yield assigned_runner.status_update_event() + return + + async def _execute_runner_failed_op( + self, op: RunnerFailedOp + ) -> AsyncGenerator[Event, None]: + """ + We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down. + """ + assigned_runner = self.assigned_runners[op.runner_id] + + if isinstance(assigned_runner.runner, RunnerSupervisor): + await ( + assigned_runner.runner.astop() + ) # astop the runner to ensure it clears out of memory. + + assigned_runner.status = FailedRunnerStatus() + yield self.assigned_runners[op.runner_id].status_update_event() + + async def _execute_task_op(self, op: ExecuteTaskOp) -> AsyncGenerator[Event, None]: + """ + This is the entry point for a chat completion starting. + While there is only one execute function, it will get called in different ways for runner 0 and runner [1, 2, 3, ...]. + Runners [1, 2, 3, ...] will run this method when a task is in 'pending' state. + Runner 0 will run this method when a task is in 'running' state. + TODO: How do we handle the logic of ensuring that n-1 nodes have started their execution before allowing the 0'th runner to start? + This is still a little unclear to me.
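+
+        Illustrative flow, as currently understood (not authoritative):
+          rank 1..n-1: task is PENDING -> start streaming and wait on rank 0
+          rank 0:      task is RUNNING -> additionally emits TaskStateUpdated and ChunkGenerated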
+ """ + assigned_runner = self.assigned_runners[op.runner_id] + + async def inner_execute(queue: asyncio.Queue[Event]) -> None: + async def running_callback(queue: asyncio.Queue[Event]) -> None: + # Called when the MLX process has been kicked off + assigned_runner.status = RunningRunnerStatus() + await queue.put(assigned_runner.status_update_event()) + + if assigned_runner.shard_metadata.device_rank == 0: + await queue.put( + TaskStateUpdated( + task_id=op.task.task_id, + task_status=TaskStatus.RUNNING, + ) + ) + + assert assigned_runner.runner is not None + assert assigned_runner.runner.runner_process.is_alive() + + async for chunk in assigned_runner.runner.stream_response( + task=op.task, request_started_callback=partial(running_callback, queue) + ): + if assigned_runner.shard_metadata.device_rank == 0: + await queue.put( + ChunkGenerated( + # TODO: at some point we will no longer have a bijection between task_id and row_id. + # So we probably want to store a mapping between these two in our Worker object. + command_id=chunk.command_id, + chunk=chunk, + ) + ) + + if op.task.task_id in self.state.tasks: + self.state.tasks[op.task.task_id].task_status = TaskStatus.COMPLETE + + if assigned_runner.shard_metadata.device_rank == 0: + # kind of hack - we don't want to wait for the round trip for this to complete + await queue.put( + TaskStateUpdated( + task_id=op.task.task_id, + task_status=TaskStatus.COMPLETE, + ) + ) + + # After a successful inference: + assigned_runner.status = LoadedRunnerStatus() + await queue.put(assigned_runner.status_update_event()) + + queue: Queue[Event] = asyncio.Queue() + task = asyncio.create_task(inner_execute(queue)) + + # TODO: Initial (prefil) timeout can be dynamic + # model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes + + try: + # Yield items from the queue + while True: + if task.done() and (exception := task.exception()): + raise exception + + try: + # Use a timeout to periodically check task status + item: Event = await asyncio.wait_for(queue.get(), timeout=0.01) + except asyncio.TimeoutError: + continue + + yield item + if isinstance(item, RunnerStatusUpdated) and isinstance( + item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus) + ): + if isinstance(item.runner_status, LoadedRunnerStatus): + assigned_runner.failures = [] + + break + finally: + # Ensure the task is cleaned up + try: + await asyncio.wait_for(task, timeout=5) + except asyncio.TimeoutError: + logger.warning( + "Timed out waiting for task cleanup after inference execution." + ) + + ## Operation Planner + + async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]: + ## It would be great if we can get rid of this async for ... yield pattern. 
+        match op.op_type:
+            case RunnerOpType.ASSIGN_RUNNER:
+                event_generator = self._execute_assign_op(op)
+            case RunnerOpType.UNASSIGN_RUNNER:
+                event_generator = self._execute_unassign_op(op)
+            case RunnerOpType.RUNNER_UP:
+                event_generator = self._execute_runner_up_op(op)
+            case RunnerOpType.RUNNER_DOWN:
+                event_generator = self._execute_runner_down_op(op)
+            case RunnerOpType.RUNNER_FAILED:
+                event_generator = self._execute_runner_failed_op(op)
+            case RunnerOpType.CHAT_COMPLETION:
+                event_generator = self._execute_task_op(op)
+
+        async for event in event_generator:
+            yield event
+
+    async def fail_runner(
+        self, e: Exception, runner_id: RunnerId
+    ) -> AsyncGenerator[Event]:
+        if runner_id in self.assigned_runners:
+            assigned_runner = self.assigned_runners[runner_id]
+
+            if assigned_runner.runner is not None:
+                await assigned_runner.runner.astop()
+                assigned_runner.runner = None
+            assigned_runner.status = FailedRunnerStatus(error_message=str(e))
+            assigned_runner.failures.append((time.time(), e))
+
+            # (The failure list is cleared again after the next successful inference.)
+            if len(assigned_runner.failures) >= 3:
+                # Too many retries: emit InstanceDeleted so the instance spins down.
+                yield InstanceDeleted(instance_id=assigned_runner.instance_id)
+
+            yield assigned_runner.status_update_event()
+
+    async def fail_task(
+        self, e: Exception, runner_id: RunnerId, task_id: TaskId
+    ) -> AsyncGenerator[Event]:
+        if runner_id in self.assigned_runners:
+            yield TaskStateUpdated(
+                task_id=task_id,
+                task_status=TaskStatus.FAILED,
+            )
+
+            yield TaskFailed(
+                task_id=task_id, error_type=str(type(e)), error_message=str(e)
+            )
+
+            async for event in self.fail_runner(e, runner_id):
+                yield event
+
+    async def event_publisher(self, event: Event) -> None:
+        fe = ForwarderEvent(
+            origin_idx=self.local_event_index,
+            origin=self.node_id,
+            tagged_event=TaggedEvent.from_(event),
+        )
+        await self.local_event_sender.send(fe)
+        self.out_for_delivery[event.event_id] = fe
+        logger.debug(
+            f"Worker published event {self.local_event_index}: {str(event)[:100]}"
+        )
+        self.local_event_index += 1
 
-def main(logfile: Path = EXO_WORKER_LOG, verbosity: int = 1):
-    logger_setup(logfile, verbosity)
-    asyncio.run(async_main())
-
-
-if __name__ == "__main__":
-    main()
+def event_relevant_to_worker(event: Event, worker: Worker):
+    # TODO
+    return True
diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py
index 250f8fd3..bf32f960 100644
--- a/src/exo/worker/plan.py
+++ b/src/exo/worker/plan.py
@@ -6,6 +6,7 @@ from exo.shared.types.events import (
 )
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
 from exo.shared.types.worker.common import RunnerId
+from exo.shared.types.worker.downloads import DownloadStatus
 from exo.shared.types.worker.instances import Instance, InstanceStatus
 from exo.shared.types.worker.ops import (
     AssignRunnerOp,
@@ -44,8 +45,12 @@ def unassign_runners(
     # If our instance is in 'downloading' or 'assigned' state, then we know the runner is stale. These are part of AssignRunnerOp and should be blocking.
    for assigned_runner_id in assigned_runners:
-        if assigned_runner_id in state_runners and isinstance(
-            state_runners[assigned_runner_id], DownloadingRunnerStatus
+        if (
+            assigned_runner_id in state_runners
+            and isinstance(state_runners[assigned_runner_id], DownloadingRunnerStatus)
+            # Not sure about this type ignore; I don't think it should be necessary
+            and state_runners[assigned_runner_id].download_progress.download_status  # type: ignore
+            != DownloadStatus.Completed
         ):
             return UnassignRunnerOp(runner_id=assigned_runner_id)
 
@@ -196,11 +201,12 @@ def spin_up_runners(
         # Need to assert all other runners are ready before we can spin up.
         ready_to_spin = True
         for runner_id in instance.shard_assignments.node_to_runner.values():
-            if (
-                runner_id in state_runners
-                and state_runners[runner_id].runner_status
-                not in [RunnerStatusType.Inactive, RunnerStatusType.Starting]
-            ):
+            if runner_id in state_runners and state_runners[
+                runner_id
+            ].runner_status not in [
+                RunnerStatusType.Inactive,
+                RunnerStatusType.Starting,
+            ]:
                 ready_to_spin = False
 
         if ready_to_spin:
diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py
index 24d96bf3..b30271b5 100644
--- a/src/exo/worker/runner/bootstrap.py
+++ b/src/exo/worker/runner/bootstrap.py
@@ -13,6 +13,7 @@ def _redirect_stderr_to_file(path: str) -> None:
     # Rebind sys.stderr so Python's own writes go to the new fd as well (line-buffered)
     sys.stderr = os.fdopen(2, "w", buffering=1, closefd=False)
 
+
 def entrypoint(raw_conn: Connection, err_path: str) -> None:
     """
     Minimal entrypoint for the spawned child process.
@@ -25,4 +26,5 @@ def entrypoint(raw_conn: Connection, err_path: str) -> None:
 
     # Import the heavy runner only after stderr is redirected
     from exo.worker.runner.runner import main
+
     asyncio.run(main(raw_conn))
diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py
index b415fb54..5cfe1014 100644
--- a/src/exo/worker/runner/generate.py
+++ b/src/exo/worker/runner/generate.py
@@ -32,6 +32,7 @@ from exo.shared.types.worker.communication import (
 
 generation_stream = mx.new_stream(mx.default_device())
 
+
 def generate_step(
     prompt: mx.array,
     model: Model,
@@ -90,14 +91,13 @@ def generate_step(
         prompt_processed_tokens = 0
         while total_prompt_tokens - prompt_processed_tokens > prefill_step_size:
-            runner_print(f'Prefilling {min(prefill_step_size, len(prompt))} tokens. Remaining tokens: {len(prompt)}. Peak memory: {mx.get_peak_memory() // 2**30} GB')
-            logits = model(
-                prompt[:prefill_step_size][None],
-                cache=prompt_cache
+            runner_print(
+                f"Prefilling {min(prefill_step_size, len(prompt))} tokens. Remaining tokens: {len(prompt)}. Peak memory: {mx.get_peak_memory() // 2**30} GB"
             )
+            logits = model(prompt[:prefill_step_size][None], cache=prompt_cache)
 
             start_time = time.time()
-            mx.eval([c.state for c in prompt_cache] + [logits]) # type: ignore
+            mx.eval([c.state for c in prompt_cache] + [logits])  # type: ignore
             eval_time = time.time() - start_time
             prompt_processed_tokens += prefill_step_size
@@ -109,34 +109,36 @@ def generate_step(
             prefill_step_size = broadcast_from_zero(prefill_step_size)
             prefill_step_size = max(1, prefill_step_size)
 
+        if prompt_processed_tokens > 0:
+            runner_print("finished prefill stage.")
-        runner_print('finished prefil.')
 
     y, logprobs = _step(input_tokens=prompt)
-    mx.async_eval(y, logprobs)  # type: ignore
+    # TODO: Why on earth is this async_eval called twice?
+    # Also, why async_eval rather than eval?
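+    # (A likely answer: async_eval schedules evaluation on the MLX stream
+    # without blocking, letting the next _step() be built while this token
+    # materializes; the immediate duplicate call below reads like a rebase
+    # artifact rather than intentional double buffering.)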
+ mx.async_eval(y, logprobs) # type: ignore n = 0 next_y: array | None = None next_logprobs: array | None = None - mx.async_eval(y, logprobs) # type: ignore + mx.async_eval(y, logprobs) # type: ignore n = 0 while True: if n != max_tokens: assert y is not None next_y, next_logprobs = _step(y) - mx.async_eval(next_y, next_logprobs) # type: ignore + mx.async_eval(next_y, next_logprobs) # type: ignore if n == 0: - mx.eval(y) # type: ignore + mx.eval(y) # type: ignore if n == max_tokens: break - yield int(y.item()), logprobs # type: ignore + yield int(y.item()), logprobs # type: ignore if n % 256 == 0: mx.clear_cache() y, logprobs = next_y, next_logprobs n += 1 - def stream_generate( model: Model, tokenizer: TokenizerWrapper, @@ -147,21 +149,22 @@ def stream_generate( prompt_cache: Optional[list[KVCache]] = None, prefill_step_size: int = 2048, ) -> Generator[GenerationResponse, None, None]: - # Try to infer if special tokens are needed add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( tokenizer.bos_token ) - prompt_array: mx.array = mx.array(tokenizer.encode(prompt, add_special_tokens=add_special_tokens)) + prompt_array: mx.array = mx.array( + tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + ) if conn is not None: conn.send_sync(TokenizedResponse(prompt_tokens=len(prompt_array))) detokenizer = tokenizer.detokenizer token_generator: Generator[Tuple[int, array], None, None] = generate_step( - prompt_array, - model, - max_tokens=max_tokens, + prompt_array, + model, + max_tokens=max_tokens, sampler=sampler, prompt_cache=prompt_cache, prefill_step_size=prefill_step_size, @@ -190,6 +193,7 @@ def stream_generate( finish_reason="stop" if token in tokenizer.eos_token_ids else "length", ) + async def warmup_inference( mlx_executor: concurrent.futures.ThreadPoolExecutor, model: Model, @@ -222,7 +226,7 @@ async def warmup_inference( prompt=warmup_prompt, max_tokens=50, sampler=sampler, - conn=None + conn=None, ): tokens_generated += 1 @@ -231,6 +235,7 @@ async def warmup_inference( return tokens_generated + async def mlx_generate( mlx_executor: concurrent.futures.ThreadPoolExecutor, model: Model, @@ -272,9 +277,11 @@ async def mlx_generate( cache_future = loop.run_in_executor( mlx_executor, - lambda: asyncio.run(make_kv_cache( - model=model, - )) + lambda: asyncio.run( + make_kv_cache( + model=model, + ) + ), ) cache = await cache_future @@ -298,4 +305,4 @@ async def mlx_generate( yield item # Wait for the executor thread to complete - await future \ No newline at end of file + await future diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index 44874a0d..0de25749 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -25,14 +25,12 @@ from exo.shared.types.worker.communication import ( runner_write_error, ) from exo.shared.types.worker.shards import ShardMetadata -from exo.shared.utils import ensure_type +from exo.utils import ensure_type from exo.worker.runner.generate import mlx_generate, warmup_inference -from exo.worker.runner.utils import get_weights_size_kb +from exo.worker.runner.utils import get_weights_size -async def main( - raw_conn: Connection -): +async def main(raw_conn: Connection): conn = AsyncConnection[RunnerResponse, RunnerMessage](raw_conn) set_conn(conn) @@ -49,9 +47,9 @@ async def main( await asyncio.sleep(timeout) mlx_setup( - int(get_weights_size_kb(model_shard_meta) // 2**10), + int(get_weights_size(model_shard_meta).in_kb // 2**10), cache_frac_of_mrwss=0.8, - 
wired_frac_of_mrwss=0.8 + wired_frac_of_mrwss=0.8, ) setup_start_time = time.time() @@ -71,9 +69,7 @@ async def main( sampler=sampler, ) runner_print(f"Warmed up by generating {toks} tokens") - await conn.send( - InitializedResponse(time_taken=time.time() - setup_start_time) - ) + await conn.send(InitializedResponse(time_taken=time.time() - setup_start_time)) while True: message = await conn.recv() @@ -121,4 +117,3 @@ async def main( except Exception as e: runner_write_error(e) - diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index d9cc638a..9dcecf62 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -12,8 +12,11 @@ from typing import Any, AsyncGenerator, Callable, Coroutine, Optional import psutil from loguru import logger +from exo.shared.global_conn import ( + AsyncConnection, +) +from exo.shared.types.chunks import GenerationChunk, TokenChunk from exo.shared.types.common import CommandId, Host -from exo.shared.types.events.chunks import GenerationChunk, TokenChunk from exo.shared.types.tasks import ChatCompletionTaskParams, Task from exo.shared.types.worker.commands_runner import ( ChatTaskMessage, @@ -28,16 +31,13 @@ from exo.shared.types.worker.commands_runner import ( TokenizedResponse, ) from exo.shared.types.worker.common import RunnerError -from exo.shared.types.worker.communication import ( - AsyncConnection, -) from exo.shared.types.worker.shards import ShardMetadata from exo.worker.runner.bootstrap import entrypoint from exo.worker.runner.utils import ( get_init_timeout, get_prefil_timeout, get_token_generate_timeout, - get_weights_size_kb, + get_weights_size, ) @@ -74,16 +74,16 @@ class RunnerSupervisor: Create and initialize a RunnerSupervisor instance. The .create() classmethod pattern is used to ensure the constructor is asynchronous. """ - ctx = mp.get_context('spawn') + ctx = mp.get_context("spawn") parent_conn, child_conn = ctx.Pipe(duplex=True) - - with tempfile.NamedTemporaryFile(prefix="child_stderr_", suffix=".log", delete=False) as tmp: + + with tempfile.NamedTemporaryFile( + prefix="child_stderr_", suffix=".log", delete=False + ) as tmp: err_path = tmp.name runner_process = Process( - target=entrypoint, - args=(child_conn, err_path), - daemon=False + target=entrypoint, args=(child_conn, err_path), daemon=False ) runner_process.start() child_conn.close() @@ -96,7 +96,7 @@ class RunnerSupervisor: runner_process=runner_process, read_queue=read_queue, conn=parent_conn, - err_path=err_path + err_path=err_path, ) logger.info(f"Initializing mlx instance with {model_shard_meta=}") @@ -124,7 +124,7 @@ class RunnerSupervisor: if self.read_task.done(): e = self.read_task.exception() await self.astop() - if e is not None: + if e is not None: raise e else: return None @@ -149,10 +149,14 @@ class RunnerSupervisor: await self.read_task # Re-raises any exception from read_task # This should never get hit. - raise RunnerError("RunnerStopped", "Runner read loop terminated unexpectedly before any response.", "") - + raise RunnerError( + "RunnerStopped", + "Runner read loop terminated unexpectedly before any response.", + "", + ) + # if we haven't read from the queue, we have timed out. - await self.astop() # TODO: This could be handled by the called or _read_with_error_check - as we don't want a false Timeout to bring the whole runner down. 
+ await self.astop() # TODO: This could be handled by the called or _read_with_error_check - as we don't want a false Timeout to bring the whole runner down. raise asyncio.TimeoutError() async def _read_coro(self): @@ -168,9 +172,11 @@ class RunnerSupervisor: match response: case PrintResponse(): # TODO: THIS IS A REALLY IMPORTANT LOG MESSAGE, AND SHOULD BE MADE PRETTIER - logger.bind(user_facing=True).info(f"{response.text}") + logger.info(f"{response.text}") case ErrorResponse(): - raise RunnerError(response.error_type, response.error_message, response.traceback) + raise RunnerError( + response.error_type, response.error_message, response.traceback + ) case _: await self.read_queue.put(response) @@ -205,7 +211,9 @@ class RunnerSupervisor: if request_started_callback is not None: await request_started_callback() - prefil_timeout = get_prefil_timeout(self.model_shard_meta, prompt_tokens=prompt_tokens) + prefil_timeout = get_prefil_timeout( + self.model_shard_meta, prompt_tokens=prompt_tokens + ) token_timeout = get_token_generate_timeout(self.model_shard_meta) timeout = prefil_timeout logger.bind(user_facing=True).info( @@ -237,7 +245,6 @@ class RunnerSupervisor: case _: raise ValueError(f"Unexpected response type found: {response}") - async def astop(self) -> None: # Cancel the stderr monitoring task async def await_task(task: asyncio.Task[Any]): @@ -255,7 +262,7 @@ class RunnerSupervisor: # Wait to make sure that the model has been unloaded from memory async def wait_for_memory_release() -> None: - required_memory_bytes = get_weights_size_kb(self.model_shard_meta) * 1024 + required_memory_bytes = get_weights_size(self.model_shard_meta).in_bytes start_time = asyncio.get_event_loop().time() while True: available_memory_bytes = psutil.virtual_memory().available @@ -315,12 +322,10 @@ class RunnerSupervisor: except Exception: cause = f"signal={sig}" - logger.bind(user_facing=True).error( - f"Runner terminated ({cause}).\n{captured}" - ) + logger.bind(user_facing=True).error(f"Runner terminated ({cause}).\n{captured}") return RunnerError( - error_type='RunnerCrash', + error_type="RunnerCrash", error_message=f"Runner terminated ({cause}).\n{captured}", traceback=traceback.format_exc(), ) diff --git a/src/exo/worker/runner/utils.py b/src/exo/worker/runner/utils.py index 1d68f377..3661ea2b 100644 --- a/src/exo/worker/runner/utils.py +++ b/src/exo/worker/runner/utils.py @@ -6,6 +6,7 @@ import psutil from loguru import logger from exo.shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS, LB_TFLOPS +from exo.shared.types.memory import Memory from exo.shared.types.worker.shards import ShardMetadata @@ -51,36 +52,36 @@ def get_runner_command() -> list[str]: return [python, "-m", "exo.worker.runner.runner"] -def get_weights_size_kb(model_shard_meta: ShardMetadata) -> float: - return ( +def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: + return Memory.from_float_kb( (model_shard_meta.end_layer - model_shard_meta.start_layer) / model_shard_meta.n_layers - * model_shard_meta.model_meta.storage_size_kilobytes + * model_shard_meta.model_meta.storage_size.in_kb ) def get_init_timeout(model_shard_meta: ShardMetadata) -> float: - weights_size_kb = get_weights_size_kb(model_shard_meta) + weights_size = get_weights_size(model_shard_meta) kbps_read = 1024 * 1024 * LB_DISK_GBPS / 3 - return weights_size_kb / kbps_read + 2.0 - + return weights_size.in_kb / kbps_read + 2.0 def _prefill_flops_for_shard(model_shard_meta: ShardMetadata, s: int) -> float: - p = get_weights_size_kb(model_shard_meta) * 1024 
+ p = get_weights_size(model_shard_meta).in_bytes flops = 2.0 * p * s # parameter-dependent GEMMs # flops += _attention_flops(meta, S) # optional S^2 term return flops + def get_prefil_timeout( model_shard_meta: ShardMetadata, prompt_tokens: int, *, effective_tflops: float = LB_TFLOPS, safety_mult: float = 1.6, - base_pad_s: float = 5.0 + base_pad_s: float = 5.0, ) -> float: """ Returns a conservative timeout (seconds) for the prefill stage. @@ -95,10 +96,9 @@ def get_prefil_timeout( return base_pad_s + safety_mult * time_seconds - def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float: - weights_size_kb = get_weights_size_kb(model_shard_meta) + weights_size = get_weights_size(model_shard_meta) kbps_read = 1024 * 1024 * LB_MEMBW_GBPS / 3 - return weights_size_kb / kbps_read + 2.0 + return weights_size.in_kb / kbps_read + 2.0 diff --git a/src/exo/worker/tests/conftest.py b/src/exo/worker/tests/conftest.py index 3f24ae5c..3c876418 100644 --- a/src/exo/worker/tests/conftest.py +++ b/src/exo/worker/tests/conftest.py @@ -1,5 +1,3 @@ -from ipaddress import IPv4Address -from logging import Logger, getLogger from typing import Callable, Optional import pytest @@ -18,40 +16,48 @@ from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.instances import Instance, InstanceStatus from exo.shared.types.worker.runners import RunnerId, ShardAssignments from exo.shared.types.worker.shards import PipelineShardMetadata +from exo.worker.main import Worker from exo.worker.tests.constants import ( COMMAND_1_ID, INSTANCE_1_ID, MODEL_A_ID, NODE_A, + NODE_B, RUNNER_1_ID, TASK_1_ID, ) +from .worker_management import ( + WorkerMailbox, + create_worker_and_mailbox, + create_worker_void_mailbox, + create_worker_with_old_mailbox, +) + @pytest.fixture -def user_message(): +def worker_void_mailbox() -> Worker: + return create_worker_void_mailbox(NODE_A) + + +@pytest.fixture +def worker_and_mailbox() -> tuple[Worker, WorkerMailbox]: + return create_worker_and_mailbox(NODE_A) + + +@pytest.fixture +def two_workers_with_shared_mailbox() -> tuple[Worker, Worker, WorkerMailbox]: + worker1, mailbox = create_worker_and_mailbox(NODE_A) + worker2 = create_worker_with_old_mailbox(NODE_B, mailbox) + return worker1, worker2, mailbox + + +@pytest.fixture +def user_message() -> str: """Override this fixture in tests to customize the message""" return "Hello, how are you?" 
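As the docstring says, individual test modules customize the prompt by redefining this fixture locally. A minimal sketch of that override pattern; the test body is illustrative, mirroring how test_inference.py swaps in its own question:

```python
import pytest


# A fixture defined in the test module shadows the conftest.py default of the
# same name; test_inference.py uses exactly this to substitute its own prompt.
@pytest.fixture
def user_message() -> str:
    return "What's the capital of Japan?"


def test_uses_overridden_prompt(user_message: str) -> None:
    assert user_message == "What's the capital of Japan?"
```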
-@pytest.fixture -def logger() -> Logger: - import logging - - logger = getLogger("test_logger") - logger.setLevel(logging.DEBUG) - - # Add console handler if none exists - if not logger.handlers: - handler = logging.StreamHandler() - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger - - @pytest.fixture async def model_meta() -> ModelMetadata: return await get_model_meta("mlx-community/Llama-3.2-1B-Instruct-4bit") @@ -62,7 +68,7 @@ def hosts(): def _hosts(count: int, offset: int = 0) -> list[Host]: return [ Host( - ip=IPv4Address("127.0.0.1"), + ip="127.0.0.1", port=5000 + offset + i, ) for i in range(count) diff --git a/src/exo/worker/tests/constants.py b/src/exo/worker/tests/constants.py index 4de842f5..85e16ed6 100644 --- a/src/exo/worker/tests/constants.py +++ b/src/exo/worker/tests/constants.py @@ -16,8 +16,8 @@ RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333") INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222") INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444") -MODEL_A_ID: Final[ModelId] = "mlx-community/Llama-3.2-1B-Instruct-4bit" -MODEL_B_ID: Final[ModelId] = "mlx-community/TinyLlama-1.1B-Chat-v1.0" +MODEL_A_ID: Final[ModelId] = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") +MODEL_B_ID: Final[ModelId] = ModelId("mlx-community/TinyLlama-1.1B-Chat-v1.0") TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555") TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666") diff --git a/src/exo/worker/tests/test_handlers/conftest.py b/src/exo/worker/tests/test_handlers/conftest.py index b05fb23a..1cfd7a41 100644 --- a/src/exo/worker/tests/test_handlers/conftest.py +++ b/src/exo/worker/tests/test_handlers/conftest.py @@ -1,10 +1,7 @@ -from logging import Logger from typing import Callable import pytest -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install from exo.shared.types.common import NodeId from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.instances import Instance @@ -13,9 +10,8 @@ from exo.shared.types.worker.ops import ( RunnerUpOp, ) from exo.shared.types.worker.runners import RunnerId -from exo.worker.download.shard_downloader import NoopShardDownloader -from exo.worker.tests.constants import INSTANCE_1_ID, NODE_A, RUNNER_1_ID -from exo.worker.worker import Worker +from exo.worker.main import Worker +from exo.worker.tests.constants import INSTANCE_1_ID, RUNNER_1_ID @pytest.fixture @@ -23,27 +19,14 @@ def user_message(): return "What, according to Douglas Adams, is the meaning of life, the universe and everything?" -@pytest.fixture -async def worker(logger: Logger): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - shard_downloader = NoopShardDownloader() - await event_log_manager.initialize() - - return Worker( - NODE_A, - shard_downloader, - worker_events=event_log_manager.global_events, - global_events=event_log_manager.global_events, - ) - - # TODO: instance_id and runner_id are selectable. 
@pytest.fixture async def worker_with_assigned_runner( - worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance] + worker_void_mailbox: Worker, + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], ): """Fixture that provides a worker with an already assigned runner.""" + worker = worker_void_mailbox instance_id = INSTANCE_1_ID runner_id = RUNNER_1_ID diff --git a/src/exo/worker/tests/test_handlers/test_handlers_happy.py b/src/exo/worker/tests/test_handlers/test_handlers_happy.py index 7accd983..86eb6ebf 100644 --- a/src/exo/worker/tests/test_handlers/test_handlers_happy.py +++ b/src/exo/worker/tests/test_handlers/test_handlers_happy.py @@ -2,6 +2,7 @@ from typing import Callable import pytest +from exo.shared.types.chunks import TokenChunk from exo.shared.types.common import NodeId from exo.shared.types.events import ( ChunkGenerated, @@ -9,7 +10,6 @@ from exo.shared.types.events import ( RunnerStatusUpdated, TaskStateUpdated, ) -from exo.shared.types.events.chunks import TokenChunk from exo.shared.types.tasks import ChatCompletionTask, TaskStatus from exo.shared.types.worker.common import RunnerId from exo.shared.types.worker.instances import Instance, InstanceId @@ -36,8 +36,10 @@ from exo.worker.tests.test_handlers.utils import read_events_op @pytest.mark.asyncio async def test_assign_op( - worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance] + worker_void_mailbox: Worker, + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], ): + worker = worker_void_mailbox instance_obj: Instance = instance(InstanceId(), worker.node_id, RUNNER_1_ID) assign_op = AssignRunnerOp( diff --git a/src/exo/worker/tests/test_integration/integration_utils.py b/src/exo/worker/tests/test_integration/integration_utils.py deleted file mode 100644 index 50154020..00000000 --- a/src/exo/worker/tests/test_integration/integration_utils.py +++ /dev/null @@ -1,145 +0,0 @@ -import asyncio -import contextlib -from contextlib import asynccontextmanager -from logging import Logger -from typing import Callable, Optional, Tuple, TypeVar - -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install -from exo.shared.types.common import NodeId -from exo.shared.types.events import ChunkGenerated, TaskStateUpdated -from exo.shared.types.events.chunks import TokenChunk -from exo.shared.types.tasks import TaskId, TaskStatus -from exo.worker.download.shard_downloader import NoopShardDownloader -from exo.worker.main import run -from exo.worker.worker import Worker - - -@asynccontextmanager -async def worker_running(node_id: NodeId, logger: Logger): - """Context manager that provides a running worker and cleans up after.""" - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - shard_downloader = NoopShardDownloader() - worker = Worker( - node_id, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - - # Start the worker task - task = asyncio.create_task(run(worker)) - - try: - yield worker, global_events - finally: - # Cleanup - task.cancel() - with contextlib.suppress(asyncio.CancelledError, asyncio.TimeoutError): - await asyncio.wait_for(task, timeout=1.0) - - # Clean up any runners - for assigned_runner in 
worker.assigned_runners.values(): - if assigned_runner.runner: - await assigned_runner.runner.astop() - -async def read_streaming_response( - global_events: AsyncSQLiteEventStorage, filter_task: Optional[TaskId] = None -) -> Tuple[bool, bool, str, int]: - # Read off all events - these should be our GenerationChunk events - seen_task_started, seen_task_finished = 0, 0 - response_string = "" - finish_reason: str | None = None - token_count = 0 - - if not filter_task: - idx = await global_events.get_last_idx() - else: - found = False - idx = 0 - while not found: - events = await global_events.get_events_since(idx) - - for event in events: - if ( - isinstance(event.event, TaskStateUpdated) - and event.event.task_status == TaskStatus.RUNNING - and event.event.task_id == filter_task - ): - found = True - idx = event.idx_in_log - 1 - break - - print(f"START IDX {idx}") - - while not finish_reason: - events = await global_events.get_events_since(idx) - if len(events) == 0: - await asyncio.sleep(0.01) - continue - idx = events[-1].idx_in_log - - for wrapped_event in events: - event = wrapped_event.event - if isinstance(event, TaskStateUpdated): - if event.task_status == TaskStatus.RUNNING: - seen_task_started += 1 - if event.task_status == TaskStatus.COMPLETE: - seen_task_finished += 1 - - if isinstance(event, ChunkGenerated) and isinstance( - event.chunk, TokenChunk - ): - response_string += event.chunk.text - token_count += 1 - if event.chunk.finish_reason: - finish_reason = event.chunk.finish_reason - - await asyncio.sleep(0.2) - - print(f"event log: {await global_events.get_events_since(0)}") - - return seen_task_started == 1, seen_task_finished == 1, response_string, token_count - - -T = TypeVar("T") - - -async def until_event_with_timeout( - global_events: AsyncSQLiteEventStorage, - event_type: type[T], - multiplicity: int = 1, - condition: Callable[[T], bool] = lambda x: True, - timeout: float = 30.0, -) -> None: - idx = await global_events.get_last_idx() - times_seen = 0 - start_time = asyncio.get_event_loop().time() - - while True: - events = await global_events.get_events_since(idx) - if events: - for wrapped_event in events: - if isinstance(wrapped_event.event, event_type) and condition( - wrapped_event.event - ): - times_seen += 1 - if times_seen >= multiplicity: - return - idx = events[-1].idx_in_log - - current_time = asyncio.get_event_loop().time() - if current_time - start_time > timeout: - raise asyncio.TimeoutError( - f"Timeout waiting for {multiplicity} events of type {event_type.__name__} " - f"(found {times_seen} in {timeout}s)" - ) - - await asyncio.sleep(0.01) diff --git a/src/exo/worker/tests/test_integration/test_inference.py b/src/exo/worker/tests/test_integration/test_inference.py index 33a3c7ee..4118896f 100644 --- a/src/exo/worker/tests/test_integration/test_inference.py +++ b/src/exo/worker/tests/test_integration/test_inference.py @@ -1,11 +1,9 @@ import asyncio -from logging import Logger from typing import Callable import pytest +from anyio import create_task_group -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams from exo.shared.types.common import CommandId, Host, NodeId from exo.shared.types.events import ( @@ -28,8 +26,7 @@ from exo.shared.types.worker.instances import ( ShardAssignments, ) from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.download.shard_downloader 
import NoopShardDownloader -from exo.worker.main import run +from exo.worker.main import Worker from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, @@ -39,11 +36,10 @@ from exo.worker.tests.constants import ( RUNNER_2_ID, TASK_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import ( +from exo.worker.tests.worker_management import ( + WorkerMailbox, read_streaming_response, - worker_running, ) -from exo.worker.worker import Worker @pytest.fixture @@ -51,12 +47,15 @@ def user_message(): """Override this fixture in tests to customize the message""" return "What's the capital of Japan?" + async def test_runner_inference( instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE @@ -93,238 +92,173 @@ async def test_runner_inference( ) await asyncio.sleep(0.3) + worker.shutdown() + # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. async def test_2_runner_inference( - logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], + two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], ): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() + worker1, worker2, global_events = two_workers_with_shared_mailbox + async with create_task_group() as tg: + tg.start_soon(worker1.run) + tg.start_soon(worker2.run) + ## Instance + model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") - global_events = event_log_manager.global_events - await global_events.delete_all_events() + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1), + }, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + ) - tasks: list[asyncio.Task[None]] = [] + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2), + ) - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker1))) + task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated(instance=instance), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, + ) - worker2 = Worker( - NODE_B, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker2))) + ( + seen_task_started, + seen_task_finished, + response_string, + _, + ) = await read_streaming_response(global_events) - ## Instance - model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") + assert seen_task_started + assert seen_task_finished + assert "tokyo" in response_string.lower() - shard_assignments = ShardAssignments( - model_id=model_id, - 
runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) + _ = global_events.collect() + await asyncio.sleep(1.0) + events = global_events.collect() + assert len(events) == 0 - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert "tokyo" in response_string.lower() - - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - - for task in tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when we cancel a task - except Exception: - pass # Suppress any other exceptions during cleanup - - - for worker in (worker1, worker2): - for assigned_runner in worker.assigned_runners.values(): - if assigned_runner.runner: - await assigned_runner.runner.astop() + await asyncio.sleep(2.0) + worker1.shutdown() + worker2.shutdown() + # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. 
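The drain/settle/drain idiom above (collect, sleep, collect again, assert empty) recurs across these tests. A small helper sketch, assuming only the synchronous `collect()` that WorkerMailbox exposes here; the helper itself is illustrative, not part of the suite:

```python
import asyncio


async def assert_quiescent(mailbox, settle_s: float = 1.0) -> None:
    # Discard everything emitted so far, then give any stray producer a
    # window to emit before asserting silence.
    mailbox.collect()
    await asyncio.sleep(settle_s)
    leftover = mailbox.collect()
    assert len(leftover) == 0, f"unexpected events after settling: {leftover}"
```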
# TODO: Multi message parallel async def test_2_runner_multi_message( - logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], + two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], ): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() + worker1, worker2, global_events = two_workers_with_shared_mailbox + async with create_task_group() as tg: + tg.start_soon(worker1.run) + tg.start_soon(worker2.run) - global_events = event_log_manager.global_events - await global_events.delete_all_events() + ## Instance + model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") - tasks: list[asyncio.Task[None]] = [] + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1), + }, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + ) - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker1))) + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2), + ) - worker2 = Worker( - NODE_B, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker2))) + # Task - we have three messages here, which is what the task is about - ## Instance - model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") + completion_create_params = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage( + role="user", content="What is the capital of France?" + ), + ChatCompletionMessage( + role="assistant", content="The capital of France is Paris." + ), + ChatCompletionMessage( + role="user", + content="Ok great. Now write me a haiku about what you can do there.", + ), + ], + stream=True, + ) - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) + task = ChatCompletionTask( + task_id=TASK_1_ID, + command_id=CommandId(), + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=completion_create_params, + ) - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) + await global_events.append_events( + [ + InstanceCreated(instance=instance), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, + ) - # Task - we have three messages here, which is what the task is about + ( + seen_task_started, + seen_task_finished, + response_string, + _, + ) = await read_streaming_response(global_events) - completion_create_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="What is the capital of France?" - ), - ChatCompletionMessage( - role="assistant", content="The capital of France is Paris." - ), - ChatCompletionMessage( - role="user", - content="Ok great. 
Now write me a haiku about what you can do there.", - ), - ], - stream=True, - ) + assert seen_task_started + assert seen_task_finished + assert any( + keyword in response_string.lower() + for keyword in ("kiss", "paris", "art", "love") + ) - task = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=CommandId(), - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=completion_create_params, - ) + _ = global_events.collect() + await asyncio.sleep(1.0) + events = global_events.collect() + assert len(events) == 0 - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert any( - keyword in response_string.lower() - for keyword in ("kiss", "paris", "art", "love") - ) - - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - for task in tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when we cancel a task - except Exception: - pass # Suppress any other exceptions during cleanup - - for worker in (worker1, worker2): - for assigned_runner in worker.assigned_runners.values(): - if assigned_runner.runner: - await assigned_runner.runner.astop() - - await asyncio.sleep(2.0) + worker1.shutdown() + worker2.shutdown() + # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. 
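One possible shape for the mechanism the TODO above asks for, sketched with anyio primitives: bound each worker's shutdown with a deadline so a hung worker fails the test instead of blocking the task group forever. `Service` stands in for the real Worker, and all names here are illustrative assumptions rather than existing APIs:

```python
import anyio


class Service:
    """Stand-in for Worker: run() blocks until shutdown() is called."""

    def __init__(self) -> None:
        self._stop = anyio.Event()
        self.stopped = anyio.Event()

    async def run(self) -> None:
        await self._stop.wait()  # placeholder for the worker's event loop
        self.stopped.set()

    def shutdown(self) -> None:
        self._stop.set()


async def stop_or_fail(svc: Service, timeout: float = 5.0) -> None:
    # fail_after raises TimeoutError if the shutdown hangs, so the test
    # fails promptly instead of deadlocking the surrounding task group.
    svc.shutdown()
    with anyio.fail_after(timeout):
        await svc.stopped.wait()


async def main() -> None:
    svc1, svc2 = Service(), Service()
    async with anyio.create_task_group() as tg:
        tg.start_soon(svc1.run)
        tg.start_soon(svc2.run)
        for svc in (svc1, svc2):
            await stop_or_fail(svc)


anyio.run(main)
```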
diff --git a/src/exo/worker/tests/test_integration/test_inference_sad.py b/src/exo/worker/tests/test_integration/test_inference_sad.py index e88bba39..82916549 100644 --- a/src/exo/worker/tests/test_integration/test_inference_sad.py +++ b/src/exo/worker/tests/test_integration/test_inference_sad.py @@ -1,11 +1,13 @@ import asyncio from collections.abc import AsyncGenerator -from logging import Logger from types import CoroutineType from typing import Any, Callable import pytest from _pytest.monkeypatch import MonkeyPatch +from anyio import create_task_group + +from exo.shared.types.chunks import GenerationChunk, TokenChunk # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py from exo.shared.types.common import NodeId @@ -15,10 +17,9 @@ from exo.shared.types.events import ( InstanceDeleted, RunnerStatusUpdated, TaskCreated, + TaskFailed, TaskStateUpdated, ) -from exo.shared.types.events._events import TaskFailed -from exo.shared.types.events.chunks import GenerationChunk, TokenChunk from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.common import InstanceId, RunnerId from exo.shared.types.worker.instances import ( @@ -26,6 +27,7 @@ from exo.shared.types.worker.instances import ( InstanceStatus, ) from exo.shared.types.worker.runners import FailedRunnerStatus +from exo.worker.main import Worker from exo.worker.runner.runner_supervisor import RunnerSupervisor from exo.worker.tests.constants import ( INSTANCE_1_ID, @@ -34,9 +36,9 @@ from exo.worker.tests.constants import ( RUNNER_1_ID, TASK_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import ( +from exo.worker.tests.worker_management import ( + WorkerMailbox, until_event_with_timeout, - worker_running, ) @@ -49,10 +51,12 @@ def user_message(): async def test_stream_response_failed_always( monkeypatch: MonkeyPatch, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - logger: Logger, chat_completion_task: Callable[[InstanceId, TaskId], Task], + worker_and_mailbox: tuple[Worker, WorkerMailbox], ) -> None: - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE @@ -61,10 +65,8 @@ async def test_stream_response_failed_always( task: Task, request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, - ) -> AsyncGenerator[GenerationChunk]: + ) -> AsyncGenerator[GenerationChunk, None]: raise RuntimeError("Simulated stream response failure") - return - yield monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) @@ -79,15 +81,15 @@ async def test_stream_response_failed_always( await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0) - events = await global_events.get_events_since(0) + events = global_events.collect() assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) + and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 3 @@ -97,13 +99,13 @@ async def test_stream_response_failed_always( [ x for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED + if isinstance(x.tagged_event.c, TaskStateUpdated) + and x.tagged_event.c.task_status 
== TaskStatus.FAILED ] ) == 3 ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events]) await global_events.append_events( [ @@ -115,14 +117,16 @@ async def test_stream_response_failed_always( ) await asyncio.sleep(0.3) + worker.shutdown() async def test_stream_response_failed_once( monkeypatch: MonkeyPatch, - logger: Logger, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): + worker, global_events = worker_and_mailbox failed_already = False original_stream_response = RunnerSupervisor.stream_response @@ -145,7 +149,8 @@ async def test_stream_response_failed_once( monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - async with worker_running(NODE_A, logger) as (worker, global_events): + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE @@ -175,14 +180,14 @@ async def test_stream_response_failed_once( assert worker.state.tasks[TASK_1_ID].error_type is None assert worker.state.tasks[TASK_1_ID].error_message is None - events = await global_events.get_events_since(0) + events = global_events.collect() assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) + and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 1 @@ -192,19 +197,19 @@ async def test_stream_response_failed_once( [ x for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED + if isinstance(x.tagged_event.c, TaskStateUpdated) + and x.tagged_event.c.task_status == TaskStatus.FAILED ] ) == 1 ) response_string = "" - events = await global_events.get_events_since(0) + events = global_events.collect() seen_task_started, seen_task_finished = False, False for wrapped_event in events: - event = wrapped_event.event + event = wrapped_event.tagged_event.c if isinstance(event, TaskStateUpdated): if event.task_status == TaskStatus.RUNNING: seen_task_started = True @@ -229,14 +234,17 @@ async def test_stream_response_failed_once( ) await asyncio.sleep(0.3) + worker.shutdown() async def test_stream_response_timeout( instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE @@ -250,17 +258,19 @@ async def test_stream_response_timeout( origin=MASTER_NODE_ID, ) - await until_event_with_timeout(global_events, TaskFailed, multiplicity=3, timeout=30.0) + await until_event_with_timeout( + global_events, TaskFailed, multiplicity=3, timeout=30.0 + ) - events = await global_events.get_events_since(0) + events = global_events.collect() print(events) assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) 
+ and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 3 @@ -270,8 +280,8 @@ async def test_stream_response_timeout( [ x for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED + if isinstance(x.tagged_event.c, TaskStateUpdated) + and x.tagged_event.c.task_status == TaskStatus.FAILED ] ) == 3 @@ -281,8 +291,8 @@ async def test_stream_response_timeout( [ x for x in events - if isinstance(x.event, TaskFailed) - and "timeouterror" in x.event.error_type.lower() + if isinstance(x.tagged_event.c, TaskFailed) + and "timeouterror" in x.tagged_event.c.error_type.lower() ] ) == 3 @@ -298,3 +308,4 @@ async def test_stream_response_timeout( ) await asyncio.sleep(0.3) + worker.shutdown() diff --git a/src/exo/worker/tests/test_integration/test_instantiation.py b/src/exo/worker/tests/test_integration/test_instantiation.py index 673afd92..fdba8ba1 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation.py +++ b/src/exo/worker/tests/test_integration/test_instantiation.py @@ -1,6 +1,7 @@ -from logging import Logger from typing import Callable +from anyio import create_task_group + # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py from exo.shared.types.common import NodeId @@ -18,26 +19,28 @@ from exo.shared.types.worker.instances import ( from exo.shared.types.worker.runners import ( FailedRunnerStatus, ) +from exo.worker.main import Worker from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, NODE_A, RUNNER_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import ( - until_event_with_timeout, - worker_running, -) +from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout async def test_runner_spinup_timeout( instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - logger: Logger, + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 + instance_value.shard_assignments.runner_to_shard[ + RUNNER_1_ID + ].should_timeout = 10 await global_events.append_events( [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID @@ -51,17 +54,18 @@ async def test_runner_spinup_timeout( ) # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) + events = global_events.collect() assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) + and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 3 ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) \ No newline at end of file + assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events]) + worker.shutdown() diff --git a/src/exo/worker/tests/test_integration/test_instantiation_sad.py b/src/exo/worker/tests/test_integration/test_instantiation_sad.py index ed4b59e4..f96c227f 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation_sad.py +++ b/src/exo/worker/tests/test_integration/test_instantiation_sad.py @@ -1,7 +1,8 @@ import asyncio -from logging import Logger 
from typing import Callable +from anyio import create_task_group + # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py from exo.shared.types.common import NodeId @@ -19,23 +20,23 @@ from exo.shared.types.worker.instances import ( from exo.shared.types.worker.runners import ( FailedRunnerStatus, ) +from exo.worker.main import Worker from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, NODE_A, RUNNER_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import ( - until_event_with_timeout, - worker_running, -) +from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout async def test_runner_spinup_exception( instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - logger: Logger, + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE instance_value.shard_assignments.runner_to_shard[ @@ -49,30 +50,35 @@ async def test_runner_spinup_exception( await asyncio.sleep(10.0) # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) + events = global_events.collect() assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) + and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 3 ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events]) + worker.shutdown() async def test_runner_spinup_timeout( instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - logger: Logger, + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 + instance_value.shard_assignments.runner_to_shard[ + RUNNER_1_ID + ].should_timeout = 10 await global_events.append_events( [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID @@ -86,17 +92,18 @@ async def test_runner_spinup_timeout( ) # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) + events = global_events.collect() assert ( len( [ x for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) + if isinstance(x.tagged_event.c, RunnerStatusUpdated) + and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus) ] ) == 3 ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events]) + worker.shutdown() diff --git a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py b/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py index 2cc9f7da..9ce8746f 100644 --- a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py +++ 
b/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py @@ -1,13 +1,11 @@ import asyncio import os import time -from logging import Logger from typing import Callable import pytest +from anyio import create_task_group -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install from exo.shared.models.model_meta import get_model_meta from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams from exo.shared.types.common import Host @@ -34,8 +32,7 @@ from exo.shared.types.worker.instances import ( ) from exo.shared.types.worker.runners import LoadedRunnerStatus from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.download.shard_downloader import NoopShardDownloader -from exo.worker.main import run +from exo.worker.main import Worker from exo.worker.tests.constants import ( COMMAND_1_ID, COMMAND_2_ID, @@ -48,16 +45,16 @@ from exo.worker.tests.constants import ( TASK_1_ID, TASK_2_ID, ) -from exo.worker.tests.test_integration.integration_utils import ( +from exo.worker.tests.worker_management import ( + WorkerMailbox, read_streaming_response, until_event_with_timeout, - worker_running, ) -from exo.worker.worker import Worker MODEL_ID = "mlx-community/Llama-3.3-70B-Instruct-4bit" SKIP = True + @pytest.fixture async def model_meta() -> ModelMetadata: return await get_model_meta(MODEL_ID) @@ -73,30 +70,32 @@ def _get_model_size_gb(path: str) -> float: total_size += os.path.getsize(filepath) return total_size / (1024**3) # Convert bytes to GB + skip = SKIP or not ( - os.path.exists( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - and _get_model_size_gb( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - > 30 + os.path.exists( + os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/") + ) + and _get_model_size_gb( + os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/") + ) + > 30 ) + @pytest.mark.skipif( skip, reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", ) async def test_ttft( - logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], + worker_and_mailbox: tuple[Worker, WorkerMailbox], ): - async with worker_running(NODE_A, logger) as (_, global_events): + from loguru import logger + + worker, global_events = worker_and_mailbox + async with create_task_group() as tg: + tg.start_soon(worker.run) ## Instance model_id = ModelId(MODEL_ID) @@ -146,8 +145,8 @@ async def test_ttft( ) print("Starting first inference...") - # Record the current event index before creating the task - idx_before_task1 = await global_events.get_last_idx() + # Clean out the current global events + _ = global_events.collect() task_created_time_1 = time.time() await global_events.append_events( @@ -158,21 +157,19 @@ async def test_ttft( first_chunk_seen_1 = False time_to_first_token_1: None | float = None while not first_chunk_seen_1: - events = await global_events.get_events_since(idx_before_task1) - for wrapped_event in events: - if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( - wrapped_event.event, "chunk" - ): - first_chunk_time_1 = time.time() - time_to_first_token_1 = first_chunk_time_1 - task_created_time_1 - first_chunk_seen_1 = True - break - if not first_chunk_seen_1: - await asyncio.sleep(0.01) + event = (await 
global_events.receive()).tagged_event.c + if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"): + first_chunk_time_1 = time.time() + time_to_first_token_1 = first_chunk_time_1 - task_created_time_1 + first_chunk_seen_1 = True + break - _, seen_task_finished_1, response_string_1, token_count_1 = await read_streaming_response( - global_events - ) + ( + _, + seen_task_finished_1, + response_string_1, + token_count_1, + ) = await read_streaming_response(global_events) total_time_1 = time.time() - task_created_time_1 assert seen_task_finished_1 @@ -201,8 +198,9 @@ async def test_ttft( ) print("Starting second inference...") + # Clean out the current global events # Record the current event index before creating the second task - idx_before_task2 = await global_events.get_last_idx() + _ = global_events.collect() task_created_time_2 = time.time() await global_events.append_events( @@ -213,21 +211,19 @@ async def test_ttft( first_chunk_seen_2 = False time_to_first_token_2: float | None = None while not first_chunk_seen_2: - events = await global_events.get_events_since(idx_before_task2) - for wrapped_event in events: - if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( - wrapped_event.event, "chunk" - ): - first_chunk_time_2 = time.time() - time_to_first_token_2 = first_chunk_time_2 - task_created_time_2 - first_chunk_seen_2 = True - break - if not first_chunk_seen_2: - await asyncio.sleep(0.01) + event = (await global_events.receive()).tagged_event.c + if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"): + first_chunk_time_2 = time.time() + time_to_first_token_2 = first_chunk_time_2 - task_created_time_2 + first_chunk_seen_2 = True + break - _, seen_task_finished_2, response_string_2, token_count_2 = await read_streaming_response( - global_events, filter_task=TASK_2_ID - ) + ( + _, + seen_task_finished_2, + response_string_2, + token_count_2, + ) = await read_streaming_response(global_events, filter_task=TASK_2_ID) total_time_2 = time.time() - task_created_time_2 assert seen_task_finished_2 @@ -239,15 +235,23 @@ async def test_ttft( prompt_tokens = 45 # Prefill TPS = prompt tokens / time to first token - prefill_tps_1 = prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0 - prefill_tps_2 = prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0 + prefill_tps_1 = ( + prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0 + ) + prefill_tps_2 = ( + prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0 + ) # Generation TPS = generated tokens / generation time # Generation time = total time - time to first token generation_time_1 = total_time_1 - time_to_first_token_1 generation_time_2 = total_time_2 - time_to_first_token_2 - generation_tps_1 = token_count_1 / generation_time_1 if generation_time_1 > 0 else 0 - generation_tps_2 = token_count_2 / generation_time_2 if generation_time_2 > 0 else 0 + generation_tps_1 = ( + token_count_1 / generation_time_1 if generation_time_1 > 0 else 0 + ) + generation_tps_2 = ( + token_count_2 / generation_time_2 if generation_time_2 > 0 else 0 + ) # Display time to first token profiling results print("\n=== Time to First Token Profiling ===") @@ -256,21 +260,35 @@ async def test_ttft( print(f" Total completion time: {total_time_1:.3f}s") print(f" Tokens generated: {token_count_1}") print(f" Response length: {len(response_string_1)} chars") - print(f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / 
{time_to_first_token_1:.3f}s)") - print(f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)") + print( + f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)" + ) + print( + f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)" + ) print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):") print(f" Time to first token: {time_to_first_token_2:.3f}s") print(f" Total completion time: {total_time_2:.3f}s") print(f" Tokens generated: {token_count_2}") print(f" Response length: {len(response_string_2)} chars") - print(f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)") - print(f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)") + print( + f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)" + ) + print( + f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)" + ) print("\nComparison:") - print(f" Second inference time to first token: {time_to_first_token_2/time_to_first_token_1:.2f}x the first") - print(f" Second inference prefill TPS: {prefill_tps_2/prefill_tps_1:.2f}x the first") - print(f" Second inference generation TPS: {generation_tps_2/generation_tps_1:.2f}x the first") + print( + f" Second inference time to first token: {time_to_first_token_2 / time_to_first_token_1:.2f}x the first" + ) + print( + f" Second inference prefill TPS: {prefill_tps_2 / prefill_tps_1:.2f}x the first" + ) + print( + f" Second inference generation TPS: {generation_tps_2 / generation_tps_1:.2f}x the first" + ) # Basic assertions to ensure responses make sense assert len(response_string_1) > 0 @@ -279,9 +297,86 @@ async def test_ttft( assert time_to_first_token_2 and time_to_first_token_2 > 0 # Cleanup - idx = await global_events.get_last_idx() + _ = global_events.collect() await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) + events = global_events.collect() + assert len(events) == 0 + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) + + await asyncio.sleep(2.0) + worker.shutdown() + + +@pytest.mark.skipif( + skip, + reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", +) +async def test_2_runner_inference( + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + chat_completion_task: Callable[[InstanceId, TaskId], Task], + two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], +): + worker1, worker2, global_events = two_workers_with_shared_mailbox + + async with create_task_group() as tg: + tg.start_soon(worker1.run) + tg.start_soon(worker2.run) + ## Instance + model_id = ModelId(MODEL_ID) + + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1), + }, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + ) + + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2), + ) + + task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + task.task_params.messages[ + 0 + ].content = "Can you explain to me 
how a bubble sort works, speaking as if you are a fairy." + task.task_params.max_tokens = 1000 + + await global_events.append_events( + [ + InstanceCreated(instance=instance), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, + ) + + ( + seen_task_started, + seen_task_finished, + response_string, + _, + ) = await read_streaming_response(global_events) + + assert seen_task_started + assert seen_task_finished + assert "swap" in response_string.lower() + + _ = global_events.collect() + await asyncio.sleep(1.0) + events = global_events.collect() assert len(events) == 0 await global_events.append_events( @@ -295,118 +390,8 @@ async def test_ttft( await asyncio.sleep(2.0) - -@pytest.mark.skipif( - skip, - reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", -) -async def test_2_runner_inference( - logger: Logger, - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - tasks: list[asyncio.Task[None]] = [] - - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker1))) - - worker2 = Worker( - NODE_B, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker2))) - - ## Instance - model_id = ModelId(MODEL_ID) - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[ - 0 - ].content = ( - "Can you explain to me how a bubble sort works, speaking as if you are a fairy." 
- ) - task.task_params.max_tokens = 1000 - - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert "swap" in response_string.lower() - - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - - for task in tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when we cancel a task - except Exception: - pass # Suppress any other exceptions during cleanup - - for worker in (worker1, worker2): - for assigned_runner in worker.assigned_runners.values(): - if assigned_runner.runner: - await assigned_runner.runner.astop() + worker1.shutdown() + worker2.shutdown() @pytest.mark.skipif( @@ -414,163 +399,132 @@ async def test_2_runner_inference( reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", ) async def test_parallel_inference( - logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], + two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], ): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() + worker1, worker2, global_events = two_workers_with_shared_mailbox - global_events = event_log_manager.global_events - await global_events.delete_all_events() + async with create_task_group() as tg: + tg.start_soon(worker1.run) + tg.start_soon(worker2.run) - tasks: list[asyncio.Task[None]] = [] + ## Instance + model_id = ModelId(MODEL_ID) - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker1))) + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1), + }, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + ) - worker2 = Worker( - NODE_B, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - tasks.append(asyncio.create_task(run(worker2))) + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2), + ) - ## Instance - model_id = ModelId(MODEL_ID) + completion_create_params_1 = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage( + role="user", content='Tell me a haiku that uses the word "pond".' 
+ ) + ], + stream=True, + max_tokens=1000, + ) + task1 = ChatCompletionTask( + task_id=TASK_1_ID, + command_id=COMMAND_1_ID, + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=completion_create_params_1, + ) - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) + completion_create_params_2 = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage( + role="user", content='Tell me a haiku that uses the word "tree".' + ) + ], + stream=True, + max_tokens=1000, + ) + task2 = ChatCompletionTask( + task_id=TASK_2_ID, + command_id=COMMAND_2_ID, + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=completion_create_params_2, + ) - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) + await global_events.append_events( + [ + InstanceCreated(instance=instance), + TaskCreated(task_id=task1.task_id, task=task1), + TaskCreated(task_id=task2.task_id, task=task2), + ], + origin=MASTER_NODE_ID, + ) - completion_create_params_1 = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content='Tell me a haiku that uses the word "pond".' - ) - ], - stream=True, - max_tokens=1000, - ) - task1 = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=completion_create_params_1, - ) + ( + seen_task_started_1, + seen_task_finished_1, + response_string_1, + _, + ) = await read_streaming_response(global_events) - completion_create_params_2 = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content='Tell me a haiku that uses the word "tree".' 
- ) - ], - stream=True, - max_tokens=1000, - ) - task2 = ChatCompletionTask( - task_id=TASK_2_ID, - command_id=COMMAND_2_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=completion_create_params_2, - ) + # Stream whichever of the two tasks has not completed yet + incomplete_task = ( + TASK_2_ID + if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE + else TASK_1_ID + ) + ( + seen_task_started_2, + seen_task_finished_2, + response_string_2, + _, + ) = await read_streaming_response(global_events, filter_task=incomplete_task) - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task1.task_id, task=task1), - TaskCreated(task_id=task2.task_id, task=task2), - ], - origin=MASTER_NODE_ID, - ) + assert seen_task_started_1 + assert seen_task_finished_1 + assert seen_task_started_2 + assert seen_task_finished_2 - ( - seen_task_started_1, - seen_task_finished_1, - response_string_1, - _, - ) = await read_streaming_response(global_events) + print(response_string_1) + print(response_string_2) - incomplete_task = ( - TASK_2_ID - if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE - else TASK_2_ID - ) + assert ("pond" in response_string_1.lower()) ^ ( + "pond" in response_string_2.lower() + ), "'pond' must appear in exactly one response" + assert ("tree" in response_string_1.lower()) ^ ( + "tree" in response_string_2.lower() + ), "'tree' must appear in exactly one response" - ( - seen_task_started_2, - seen_task_finished_2, - response_string_2, - _, - ) = await read_streaming_response(global_events, filter_task=incomplete_task) + _ = global_events.collect() + await asyncio.sleep(1.0) + events = global_events.collect() + assert len(events) == 0 - assert seen_task_started_1 - assert seen_task_finished_1 - assert seen_task_started_2 - assert seen_task_finished_2 + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) - print(response_string_1) - print(response_string_2) + await asyncio.sleep(2.0) - assert ("pond" in response_string_1.lower()) ^ ( - "pond" in response_string_2.lower() - ), "'pond' must appear in exactly one response" - assert ("tree" in response_string_1.lower()) ^ ( - "tree" in response_string_2.lower() - ), "'tree' must appear in exactly one response" + worker1.shutdown() + worker2.shutdown() - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - - for task in tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when we cancel a task - except Exception: - pass # Suppress any other exceptions during cleanup - - for worker in (worker1, worker2): - for assigned_runner in worker.assigned_runners.values(): - if assigned_runner.runner: - await assigned_runner.runner.astop() + worker1.shutdown() + worker2.shutdown() diff --git a/src/exo/worker/tests/test_plan/test_worker_plan.py b/src/exo/worker/tests/test_plan/test_worker_plan.py index bbb59fc1..c04038a5 100644 --- a/src/exo/worker/tests/test_plan/test_worker_plan.py +++ b/src/exo/worker/tests/test_plan/test_worker_plan.py @@ -1,10 +1,5 @@ -from __future__ import annotations - -import logging - import pytest -from exo.shared.logging import logger_test_install from exo.shared.types.api import ChatCompletionMessage 
from exo.shared.types.state import State from exo.shared.types.tasks import ( @@ -13,7 +8,7 @@ from exo.shared.types.tasks import ( TaskStatus, TaskType, ) -from exo.shared.types.worker.common import NodeStatus +from exo.shared.types.worker.common import WorkerStatus from exo.shared.types.worker.downloads import ( DownloadPending, ) @@ -34,7 +29,6 @@ from exo.shared.types.worker.runners import ( ) from exo.shared.types.worker.shards import PipelineShardMetadata from exo.worker.common import AssignedRunner -from exo.worker.download.shard_downloader import NoopShardDownloader from exo.worker.main import Worker from exo.worker.plan import plan from exo.worker.tests.constants import ( @@ -74,7 +68,7 @@ def _get_test_cases() -> list[PlanTestCase]: description="no runners -> no-op", in_process_runners=[], state=State( - node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={} + node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={} ), expected_op=None, ), @@ -144,7 +138,7 @@ def _get_test_cases() -> list[PlanTestCase]: ) ], state=State( - node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={} + node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={} ), expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID), ), @@ -496,7 +490,7 @@ def _get_test_cases() -> list[PlanTestCase]: # We use a factory to delay test case generation until tmp_path is available. [pytest.param(c, id=c.id()) for c in _get_test_cases()], ) -def test_worker_plan(case: PlanTestCase) -> None: +def test_worker_plan(case: PlanTestCase, worker_void_mailbox: Worker) -> None: """Exercise Worker.plan across declarative scenarios.""" print(f"----- case: {case.description}") @@ -505,17 +499,7 @@ def test_worker_plan(case: PlanTestCase) -> None: test_cases = {c.description: c for c in _get_test_cases()} case = test_cases[case.description] - node_id = NODE_A - - logger = logging.getLogger("test_worker_plan") - logger_test_install(logger) - shard_downloader = NoopShardDownloader() - worker = Worker( - node_id=node_id, - shard_downloader=shard_downloader, - worker_events=None, - global_events=None, - ) + worker = worker_void_mailbox runner_config: InProcessRunner for runner_config in case.in_process_runners: @@ -532,7 +516,7 @@ def test_worker_plan(case: PlanTestCase) -> None: runner_node = node break - if runner_node != node_id: + if runner_node != worker.node_id: # This runner belongs to a different node, skip it continue diff --git a/src/exo/worker/tests/test_plan/test_worker_plan_utils.py b/src/exo/worker/tests/test_plan/test_worker_plan_utils.py index dce20444..4c7d12f9 100644 --- a/src/exo/worker/tests/test_plan/test_worker_plan_utils.py +++ b/src/exo/worker/tests/test_plan/test_worker_plan_utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dataclasses import dataclass from typing import List, NotRequired, Optional, TypedDict @@ -8,10 +6,11 @@ from typing_extensions import Literal from exo.shared.models.model_cards import MODEL_CARDS, ModelCard from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.memory import Memory from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.state import State from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType -from exo.shared.types.worker.common import InstanceId, NodeStatus, RunnerId +from exo.shared.types.worker.common import InstanceId, RunnerId, WorkerStatus from 
exo.shared.types.worker.downloads import DownloadOngoing, DownloadProgressData from exo.shared.types.worker.instances import Instance, InstanceStatus from exo.shared.types.worker.ops import RunnerOp @@ -117,7 +116,9 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus: return DownloadingRunnerStatus( download_progress=DownloadOngoing( node_id=node_id, - download_progress=DownloadProgressData(total_bytes=1, downloaded_bytes=0), + download_progress=DownloadProgressData( + total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0) + ), ) ) @@ -129,9 +130,9 @@ def make_model_meta(model_id: str) -> ModelMetadata: model_card = card return ModelMetadata( - model_id=model_id, + model_id=ModelId(model_id), pretty_name=model_card.model_id, - storage_size_kilobytes=10**6, + storage_size=Memory.from_kb(10**6), n_layers=16, ) @@ -146,7 +147,7 @@ def make_instance( runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]], model_id: ModelId = MODEL_A_ID, instance_status: InstanceStatus = InstanceStatus.ACTIVE, -) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, NodeStatus]]: +) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, WorkerStatus]]: """Creates an instance with one or more runners.""" runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {} node_to_runner: dict[NodeId, RunnerId] = {} @@ -170,13 +171,13 @@ def make_instance( ) # Currently nodes are only ever idle - as if they were running we would be blocking - so we wouldn't be running plan() - # node_statuses = {node_id: NodeStatus.Idle for _, node_id, _, _ in runner_specs} - node_statuses: dict[NodeId, NodeStatus] = {} + # node_statuses = {node_id: WorkerStatus.Idle for _, node_id, _, _ in runner_specs} + node_statuses: dict[NodeId, WorkerStatus] = {} for _runner_id, node_id, _, status in runner_specs: if isinstance(status, RunningRunnerStatus): - node_statuses[node_id] = NodeStatus.Running + node_statuses[node_id] = WorkerStatus.Running else: - node_statuses[node_id] = NodeStatus.Idle + node_statuses[node_id] = WorkerStatus.Idle runner_statuses = {runner_id: status for runner_id, _, _, status in runner_specs} return instance, runner_statuses, node_statuses @@ -195,7 +196,7 @@ def make_state( tasks = {} instances: dict[InstanceId, Instance] = {} all_runner_statuses: dict[RunnerId, RunnerStatus] = {} - all_node_statuses: dict[NodeId, NodeStatus] = {} + all_node_statuses: dict[NodeId, WorkerStatus] = {} for inst_id, specs in runner_specs_per_instance.items(): # Build per-instance data using make_instance diff --git a/src/exo/worker/tests/test_runner_connection.py b/src/exo/worker/tests/test_runner_connection.py index 29e2f1ba..0eccf5d3 100644 --- a/src/exo/worker/tests/test_runner_connection.py +++ b/src/exo/worker/tests/test_runner_connection.py @@ -1,20 +1,18 @@ import asyncio import os -from logging import Logger from typing import Callable import pytest +from anyio import create_task_group, move_on_after -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install from exo.shared.types.common import Host from exo.shared.types.events import InstanceCreated, InstanceDeleted from exo.shared.types.models import ModelId from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments from exo.shared.types.worker.runners import FailedRunnerStatus from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.download.shard_downloader import NoopShardDownloader 
-from exo.worker.main import run +from exo.worker.main import Worker +from exo.worker.runner.runner_supervisor import RunnerSupervisor from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, @@ -23,7 +21,7 @@ from exo.worker.tests.constants import ( RUNNER_1_ID, RUNNER_2_ID, ) -from exo.worker.worker import Worker +from exo.worker.tests.worker_management import WorkerMailbox @pytest.fixture @@ -36,43 +34,31 @@ def user_message() -> str: reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set", ) async def check_runner_connection( - logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], + two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], ) -> bool: - logger_test_install(logger) + async def wait_for_runner_supervisor( + worker: Worker, timeout: float = 5.0 + ) -> RunnerSupervisor | None: + with move_on_after(timeout): + while True: + assigned_runners = list(worker.assigned_runners.values()) + if assigned_runners: + runner = assigned_runners[0].runner + if isinstance(runner, RunnerSupervisor): + print("breaking because success") + return runner + if isinstance(assigned_runners[0].status, FailedRunnerStatus): + print("breaking because failed") + return runner + await asyncio.sleep(0.001) + + worker1, worker2, global_events = two_workers_with_shared_mailbox # Track all tasks and workers for cleanup - tasks: list[asyncio.Task[None]] = [] - workers: list[Worker] = [] - - try: - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - workers.append(worker1) - task1 = asyncio.create_task(run(worker1)) - tasks.append(task1) - - worker2 = Worker( - NODE_B, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - workers.append(worker2) - task2 = asyncio.create_task(run(worker2)) - tasks.append(task2) - + async with create_task_group() as tg: + tg.start_soon(worker1.run) + tg.start_soon(worker2.run) model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") shard_assignments = ShardAssignments( @@ -98,28 +84,11 @@ async def check_runner_connection( origin=MASTER_NODE_ID, ) - from exo.worker.runner.runner_supervisor import RunnerSupervisor - - async def wait_for_runner_supervisor( - worker: Worker, timeout: float = 5.0 - ) -> RunnerSupervisor | None: - end = asyncio.get_event_loop().time() + timeout - while True: - assigned_runners = list(worker.assigned_runners.values()) - if assigned_runners: - runner = assigned_runners[0].runner - if isinstance(runner, RunnerSupervisor): - print("breaking because success") - return runner - if isinstance(assigned_runners[0].status, FailedRunnerStatus): - print("breaking because failed") - return runner - if asyncio.get_event_loop().time() > end: - raise TimeoutError("RunnerSupervisor was not set within timeout") - await asyncio.sleep(0.001) - runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0) - ret = runner_supervisor is not None and runner_supervisor.runner_process.is_alive() + ret = ( + runner_supervisor is not None + and runner_supervisor.runner_process.is_alive() + ) await global_events.append_events( [ @@ -132,14 +101,13 @@ async def 
check_runner_connection( await asyncio.sleep(0.5) - return ret - finally: - # Cancel all worker tasks - for task in tasks: - task.cancel() + worker1.shutdown() + worker2.shutdown() + tg.cancel_scope.cancel() - # Wait for cancellation to complete - await asyncio.gather(*tasks, return_exceptions=True) + return ret + # should be unreachable + raise # Check Running status @@ -147,7 +115,6 @@ async def check_runner_connection( # # not now. # def test_runner_connection_stress( -# logger: Logger, # pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], # hosts: Callable[[int], list[Host]], # chat_completion_task: Callable[[InstanceId, str], Task], @@ -157,12 +124,10 @@ async def check_runner_connection( # # not now. # def test_runner_connection_stress( -# logger: Logger, # pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], # hosts: Callable[[int], list[Host]], # chat_completion_task: Callable[[InstanceId, str], Task], # ) -> None: -# logger_test_install(logger) # total_runs = 100 # successes = 0 diff --git a/src/exo/worker/tests/test_serdes.py b/src/exo/worker/tests/test_serdes.py index 4239b17d..bee86310 100644 --- a/src/exo/worker/tests/test_serdes.py +++ b/src/exo/worker/tests/test_serdes.py @@ -1,4 +1,4 @@ -from typing import Callable, TypeVar +from typing import Callable from pydantic import BaseModel, TypeAdapter @@ -12,10 +12,8 @@ from exo.shared.types.worker.commands_runner import ( from exo.shared.types.worker.common import InstanceId from exo.shared.types.worker.shards import PipelineShardMetadata -T = TypeVar("T", bound=BaseModel) - -def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]): +def assert_equal_serdes[T: BaseModel](obj: T, typeadapter: TypeAdapter[T]): encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" decoded: T = typeadapter.validate_json(encoded) diff --git a/src/exo/worker/tests/test_spinup_timeout.py b/src/exo/worker/tests/test_spinup_timeout.py index 501ca649..3780023a 100644 --- a/src/exo/worker/tests/test_spinup_timeout.py +++ b/src/exo/worker/tests/test_spinup_timeout.py @@ -7,8 +7,8 @@ import pytest from exo.shared.types.events import ( Event, + RunnerStatusUpdated, ) -from exo.shared.types.events._events import RunnerStatusUpdated from exo.shared.types.tasks import Task, TaskId from exo.shared.types.worker.instances import Instance, InstanceId from exo.shared.types.worker.ops import ( diff --git a/src/exo/worker/tests/test_supervisor/test_long.py b/src/exo/worker/tests/test_supervisor/test_long.py index 51381ba5..89f81969 100644 --- a/src/exo/worker/tests/test_supervisor/test_long.py +++ b/src/exo/worker/tests/test_supervisor/test_long.py @@ -1,14 +1,12 @@ import asyncio -from logging import Logger from typing import Callable import pytest -from exo.shared.logging import logger_test_install from exo.shared.models.model_cards import MODEL_CARDS from exo.shared.openai_compat import FinishReason +from exo.shared.types.chunks import TokenChunk from exo.shared.types.common import Host -from exo.shared.types.events.chunks import TokenChunk from exo.shared.types.tasks import ( Task, TaskId, @@ -23,6 +21,7 @@ def user_message(): """Override the default message to ask about France's capital""" return "What is the capital of France?" + @pytest.fixture def lorem_ipsum() -> str: return """ @@ -48,18 +47,17 @@ Curabitur non vehicula purus. Cras et justo risus. Duis et rutrum urna. Aliquam Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Praesent porttitor tempor ligula. 
Quisque mollis arcu in metus ornare pellentesque. Aenean ultrices mollis quam quis sodales. Maecenas a cursus elit, id gravida tortor. Donec vel purus magna. Aliquam elementum est sed convallis fermentum. Nam nec eros arcu. Pellentesque sed eros a lacus sagittis maximus. Integer et tellus id libero dapibus convallis. Maecenas viverra, purus facilisis porttitor tincidunt, tellus lacus elementum dui, sed porttitor sem justo a lorem. Curabitur ipsum odio, efficitur quis efficitur at, tempus aliquet nisi. Aliquam ultrices tortor in arcu vulputate, vel iaculis lorem facilisis. Cras eleifend laoreet feugiat. Integer placerat blandit sem, mattis elementum purus pellentesque quis. Etiam vel arcu ut mi commodo placerat non id tortor. """ + @pytest.mark.asyncio async def test_supervisor_long_prompt_response( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], lorem_ipsum: str, - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) - model_meta = MODEL_CARDS['llama-3.2-1b'].metadata + model_meta = MODEL_CARDS["llama-3.2-1b"].metadata model_shard_meta = PipelineShardMetadata( model_meta=model_meta, device_rank=0, @@ -83,10 +81,7 @@ async def test_supervisor_long_prompt_response( task = chat_completion_task(instance_id, TaskId()) task.task_params.messages[0].content = lorem_ipsum * 3 - - async for chunk in supervisor.stream_response( - task=task - ): + async for chunk in supervisor.stream_response(task=task): if isinstance(chunk, TokenChunk): full_response += chunk.text @@ -102,21 +97,21 @@ async def test_supervisor_two_node_long_prompt_response( hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], lorem_ipsum: str, - logger: Logger, ): """Test two-node long prompt inference""" - logger_test_install(logger) instance_id = InstanceId() async def create_supervisor(shard_idx: int) -> RunnerSupervisor: - model_meta = MODEL_CARDS['llama-3.2-1b'].metadata + model_meta = MODEL_CARDS["llama-3.2-1b"].metadata model_shard_meta = PipelineShardMetadata( model_meta=model_meta, device_rank=shard_idx, world_size=2, n_layers=model_meta.n_layers, start_layer=0 if shard_idx == 0 else model_meta.n_layers // 2, - end_layer=model_meta.n_layers // 2 if shard_idx == 0 else model_meta.n_layers, + end_layer=model_meta.n_layers // 2 + if shard_idx == 0 + else model_meta.n_layers, ) supervisor = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, @@ -166,4 +161,3 @@ async def test_supervisor_two_node_long_prompt_response( finally: await supervisor_0.astop() await supervisor_1.astop() - diff --git a/src/exo/worker/tests/test_supervisor/test_memory.py b/src/exo/worker/tests/test_supervisor/test_memory.py index e250e5a4..140923a2 100644 --- a/src/exo/worker/tests/test_supervisor/test_memory.py +++ b/src/exo/worker/tests/test_supervisor/test_memory.py @@ -1,11 +1,9 @@ -from logging import Logger from multiprocessing import Process from typing import Callable import psutil import pytest -from exo.shared.logging import logger_test_install from exo.shared.models.model_meta import get_model_meta from exo.shared.types.common import Host from exo.shared.types.models import ModelMetadata @@ -35,9 +33,7 @@ async def test_supervisor_inference_exception( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: 
Logger, ): - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) supervisor = await RunnerSupervisor.create( diff --git a/src/exo/worker/tests/test_supervisor/test_oom.py b/src/exo/worker/tests/test_supervisor/test_oom.py index 9b1b4778..8ea4c2b8 100644 --- a/src/exo/worker/tests/test_supervisor/test_oom.py +++ b/src/exo/worker/tests/test_supervisor/test_oom.py @@ -1,9 +1,7 @@ -from logging import Logger from typing import Callable import pytest -from exo.shared.logging import logger_test_install from exo.shared.types.common import Host from exo.shared.types.tasks import ( Task, @@ -29,9 +27,7 @@ async def test_supervisor_catches_oom( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): - logger_test_install(logger) """Test that asking for the capital of France returns 'Paris' in the response""" model_shard_meta = pipeline_shard_meta(1, 0) diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor.py b/src/exo/worker/tests/test_supervisor/test_supervisor.py index 1a7f7fb3..6b44c9b9 100644 --- a/src/exo/worker/tests/test_supervisor/test_supervisor.py +++ b/src/exo/worker/tests/test_supervisor/test_supervisor.py @@ -1,13 +1,11 @@ import asyncio -from logging import Logger from typing import Callable import pytest -from exo.shared.logging import logger_test_install from exo.shared.openai_compat import FinishReason +from exo.shared.types.chunks import TokenChunk from exo.shared.types.common import Host -from exo.shared.types.events.chunks import TokenChunk from exo.shared.types.tasks import ( ChatCompletionTaskParams, Task, @@ -30,10 +28,8 @@ async def test_supervisor_single_node_response( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) instance_id = InstanceId() @@ -71,10 +67,8 @@ async def test_supervisor_two_node_response( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) instance_id = InstanceId() async def create_supervisor(shard_idx: int) -> RunnerSupervisor: @@ -136,10 +130,8 @@ async def test_supervisor_early_stopping( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) instance_id = InstanceId() @@ -190,10 +182,8 @@ async def test_supervisor_early_stopping( async def test_supervisor_handles_terminated_runner( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], - logger: Logger, ): """Test that the supervisor handles a terminated runner""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) supervisor = await RunnerSupervisor.create( @@ -214,10 +204,8 @@ async def test_supervisor_handles_terminated_runner( async def test_supervisor_handles_killed_runner( pipeline_shard_meta: Callable[..., 
PipelineShardMetadata], hosts: Callable[..., list[Host]], - logger: Logger, ): """Test that the supervisor handles a killed runner""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) supervisor = await RunnerSupervisor.create( diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py b/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py index 87a06273..11d24f2b 100644 --- a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py +++ b/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py @@ -1,10 +1,8 @@ import asyncio -from logging import Logger from typing import Callable import pytest -from exo.shared.logging import logger_test_install from exo.shared.types.common import Host from exo.shared.types.tasks import Task, TaskId from exo.shared.types.worker.common import InstanceId, RunnerError @@ -17,10 +15,8 @@ from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID async def test_supervisor_instantiation_exception( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) model_shard_meta.immediate_exception = True @@ -40,10 +36,8 @@ async def test_supervisor_instantiation_exception( async def test_supervisor_instantiation_timeout( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) model_shard_meta.should_timeout = 10 # timeout after 10s @@ -59,10 +53,8 @@ async def test_supervisor_inference_exception( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) supervisor = await RunnerSupervisor.create( @@ -82,10 +74,8 @@ async def test_supervisor_inference_timeout( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" - logger_test_install(logger) model_shard_meta = pipeline_shard_meta(1, 0) supervisor = await RunnerSupervisor.create( diff --git a/src/exo/worker/tests/worker_management.py b/src/exo/worker/tests/worker_management.py new file mode 100644 index 00000000..34b6db13 --- /dev/null +++ b/src/exo/worker/tests/worker_management.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass +from typing import Callable + +from anyio import fail_after + +from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent +from exo.shared.types.chunks import TokenChunk +from exo.shared.types.common import NodeId +from exo.shared.types.events import ChunkGenerated, Event, TaggedEvent, TaskStateUpdated +from exo.shared.types.tasks import TaskId, TaskStatus +from exo.utils.channels import Receiver, Sender, channel +from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader +from exo.worker.main import Worker + + +@dataclass +class WorkerMailbox: + sender: Sender[ForwarderEvent] + receiver: 
Receiver[ForwarderEvent] + counter: int = 0 + + async def append_events(self, events: list[Event], *, origin: NodeId): + for event in events: + await self.sender.send( + ForwarderEvent( + origin=origin, + tagged_event=TaggedEvent.from_(event), + origin_idx=self.counter, + ) + ) + self.counter += 1 + + async def receive(self) -> ForwarderEvent: + return await self.receiver.receive() + + def collect(self) -> list[ForwarderEvent]: + # Drain and return the mailbox's currently held events + return self.receiver.collect() + + +def create_worker_void_mailbox( + node_id: NodeId, shard_downloader: ShardDownloader | None = None +) -> Worker: + if shard_downloader is None: + shard_downloader = NoopShardDownloader() + # Every channel counterpart is dropped immediately, so this worker neither + # receives events nor delivers them anywhere: a "void" mailbox. + return Worker( + node_id, + shard_downloader=shard_downloader, + initial_connection_messages=[], + connection_message_receiver=channel[ConnectionMessage]()[1], + global_event_receiver=channel[ForwarderEvent]()[1], + local_event_sender=channel[ForwarderEvent]()[0], + command_sender=channel[ForwarderCommand]()[0], + ) + + +def create_worker_and_mailbox( + node_id: NodeId, shard_downloader: ShardDownloader | None = None +) -> tuple[Worker, WorkerMailbox]: + if shard_downloader is None: + shard_downloader = NoopShardDownloader() + + # worker -> test: the worker's local events land in the mailbox's receiver + lsend, receiver = channel[ForwarderEvent]() + # test -> worker: events appended via the mailbox arrive as the worker's global events + sender, grecv = channel[ForwarderEvent]() + worker = Worker( + node_id, + shard_downloader=shard_downloader, + initial_connection_messages=[], + connection_message_receiver=channel[ConnectionMessage]()[1], + global_event_receiver=grecv, + local_event_sender=lsend, + command_sender=channel[ForwarderCommand]()[0], + ) + return worker, WorkerMailbox(sender, receiver) + + +def create_worker_with_old_mailbox( + node_id: NodeId, + mailbox: WorkerMailbox, + shard_downloader: ShardDownloader | None = None, +) -> Worker: + if shard_downloader is None: + shard_downloader = NoopShardDownloader() + # This function is subtly complex, come talk to Evan if you want to know what it's actually doing. 
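+ # (Best-effort summary, inferred from the channel wiring below.) The returned + # worker taps the same channels as the existing mailbox: it reads global events + # from a clone of the sender the mailbox uses to inject events, and publishes + # its local events through a clone of the mailbox's receiving end. Tests use + # this to stand up a replacement worker (e.g., a simulated node restart) that + # shares the original event stream.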
+ worker = Worker( + node_id, + shard_downloader=shard_downloader, + initial_connection_messages=[], + connection_message_receiver=channel[ConnectionMessage]()[1], + global_event_receiver=mailbox.sender.clone_receiver(), + local_event_sender=mailbox.receiver.clone_sender(), + command_sender=channel[ForwarderCommand]()[0], + ) + return worker + + +async def read_streaming_response( + global_event_receiver: WorkerMailbox, filter_task: TaskId | None = None +) -> tuple[bool, bool, str, int]: + # Read off all events - these should be our GenerationChunk events + seen_task_started = 0 + seen_task_finished = 0 + response_string = "" + token_count = 0 + extra_events: list[Event] = [] + + with fail_after(10.0): + event = (await global_event_receiver.receive()).tagged_event.c + extra_events.append(event) + + if filter_task: + # Skip ahead until the task we are filtering on reports RUNNING + while not ( + isinstance(event, TaskStateUpdated) + and event.task_status == TaskStatus.RUNNING + and event.task_id == filter_task + ): + event = (await global_event_receiver.receive()).tagged_event.c + extra_events.append(event) + + # Tally the events buffered while searching... + for event in extra_events: + if isinstance(event, TaskStateUpdated): + if event.task_status == TaskStatus.RUNNING: + seen_task_started += 1 + if event.task_status == TaskStatus.COMPLETE: + seen_task_finished += 1 + if isinstance(event, ChunkGenerated) and isinstance( + event.chunk, TokenChunk + ): + response_string += event.chunk.text + token_count += 1 + + # ...then keep consuming live events until the task completes + while not seen_task_finished: + event = (await global_event_receiver.receive()).tagged_event.c + if isinstance(event, TaskStateUpdated): + if event.task_status == TaskStatus.RUNNING: + seen_task_started += 1 + if event.task_status == TaskStatus.COMPLETE: + seen_task_finished += 1 + if isinstance(event, ChunkGenerated) and isinstance( + event.chunk, TokenChunk + ): + response_string += event.chunk.text + token_count += 1 + + return seen_task_started == 1, seen_task_finished == 1, response_string, token_count + + +async def until_event_with_timeout[T]( + global_event_receiver: WorkerMailbox, + event_type: type[T], + multiplicity: int = 1, + condition: Callable[[T], bool] = lambda x: True, + timeout: float = 30.0, +) -> None: + times_seen = 0 + + with fail_after(timeout): + while times_seen < multiplicity: + event = (await global_event_receiver.receive()).tagged_event.c + if isinstance(event, event_type) and condition(event): + times_seen += 1
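+ + +# Sketch of intended usage, mirroring the integration tests in this change +# (NODE_A / MASTER_NODE_ID come from exo.worker.tests.constants; the task-group +# plumbing is the same anyio machinery those tests use): +# +# worker, mailbox = create_worker_and_mailbox(NODE_A) +# async with anyio.create_task_group() as tg: +# tg.start_soon(worker.run) +# await mailbox.append_events([InstanceCreated(instance=...)], origin=MASTER_NODE_ID) +# await until_event_with_timeout(mailbox, RunnerStatusUpdated) +# worker.shutdown()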
diff --git a/src/exo/worker/utils/profile.py b/src/exo/worker/utils/profile.py index ab4d3e33..174c1a41 100644 --- a/src/exo/worker/utils/profile.py +++ b/src/exo/worker/utils/profile.py @@ -3,6 +3,7 @@ import os import platform from typing import Any, Callable, Coroutine +import anyio from loguru import logger from exo.shared.types.profiling import ( @@ -75,7 +76,7 @@ async def start_polling_node_metrics( chip_id=system_info.chip_id, friendly_name=mac_friendly_name or "Unknown", network_interfaces=network_interfaces, - memory=MemoryPerformanceProfile( + memory=MemoryPerformanceProfile.from_bytes( ram_total=total_mem, ram_available=override_memory if override_memory @@ -125,4 +126,4 @@ async def start_polling_node_metrics( # Catch-all to ensure the monitor keeps running. logger.opt(exception=e).error("Resource Monitor encountered error") finally: - await asyncio.sleep(poll_interval_s) + await anyio.sleep(poll_interval_s) diff --git a/src/exo/worker/worker.py b/src/exo/worker/worker.py deleted file mode 100644 index 606f487a..00000000 --- a/src/exo/worker/worker.py +++ /dev/null @@ -1,429 +0,0 @@ -import asyncio -import time -from asyncio import Queue -from functools import partial -from typing import AsyncGenerator, Optional - -from loguru import logger - -from exo.shared.db.sqlite import AsyncSQLiteEventStorage -from exo.shared.types.common import NodeId -from exo.shared.types.events import ( - ChunkGenerated, - Event, - InstanceDeleted, - RunnerDeleted, - RunnerStatusUpdated, - TaskFailed, - TaskStateUpdated, - ) -from exo.shared.types.state import State -from exo.shared.types.tasks import TaskId, TaskStatus -from exo.shared.types.worker.common import RunnerId -from exo.shared.types.worker.downloads import ( - DownloadCompleted, - DownloadOngoing, - DownloadPending, - DownloadProgressData, - ) -from exo.shared.types.worker.ops import ( - AssignRunnerOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerFailedOp, - RunnerOp, - RunnerOpType, - RunnerUpOp, - UnassignRunnerOp, - ) -from exo.shared.types.worker.runners import ( - DownloadingRunnerStatus, - FailedRunnerStatus, - InactiveRunnerStatus, - LoadedRunnerStatus, - RunningRunnerStatus, - StartingRunnerStatus, - ) -from exo.shared.types.worker.shards import ShardMetadata -from exo.worker.common import AssignedRunner -from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader -from exo.worker.runner.runner_supervisor import RunnerSupervisor - - -class Worker: - def __init__( - self, - node_id: NodeId, - shard_downloader: ShardDownloader, - worker_events: AsyncSQLiteEventStorage | None, - global_events: AsyncSQLiteEventStorage | None, - ): - self.node_id: NodeId = node_id - self.state: State = State() - self.shard_downloader: ShardDownloader = shard_downloader - self.worker_events: AsyncSQLiteEventStorage | None = ( - worker_events # worker_events is None in some tests. 
- ) - self.global_events: AsyncSQLiteEventStorage | None = global_events - - self.assigned_runners: dict[RunnerId, AssignedRunner] = {} - self._task: asyncio.Task[None] | None = None - - ## Op Executors - - def _create_assigned_runner(self, op: AssignRunnerOp) -> AssignedRunner: - """Creates and stores a new AssignedRunner with initial downloading status.""" - assigned_runner = AssignedRunner( - runner_id=op.runner_id, - instance_id=op.instance_id, - shard_metadata=op.shard_metadata, - hosts=op.hosts, - status=DownloadingRunnerStatus( - download_progress=DownloadPending(node_id=self.node_id) - ), - runner=None, - ) - self.assigned_runners[op.runner_id] = assigned_runner - return assigned_runner - - async def _update_runner_status_to_completed_then_inactive( - self, assigned_runner: AssignedRunner - ) -> AsyncGenerator[Event, None]: - """Updates runner status from downloading to completed, then to inactive.""" - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadCompleted(node_id=self.node_id) - ) - yield assigned_runner.status_update_event() - - assigned_runner.status = InactiveRunnerStatus() - yield assigned_runner.status_update_event() - - async def _handle_already_downloaded_shard( - self, assigned_runner: AssignedRunner - ) -> AsyncGenerator[Event, None]: - """Handles the case where the shard is already downloaded.""" - async for event in self._update_runner_status_to_completed_then_inactive( - assigned_runner - ): - yield event - - async def _handle_shard_download_process( - self, - assigned_runner: AssignedRunner, - op: AssignRunnerOp, - initial_progress: RepoDownloadProgress, - ) -> AsyncGenerator[Event, None]: - """Manages the shard download process with progress tracking.""" - # Set initial ongoing status - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadOngoing( - node_id=self.node_id, - download_progress=DownloadProgressData( - total_bytes=initial_progress.total_bytes, - downloaded_bytes=initial_progress.downloaded_bytes, - ), - ) - ) - yield assigned_runner.status_update_event() - - # Set up download progress tracking - download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue() - - def download_progress_callback( - shard: ShardMetadata, progress: RepoDownloadProgress - ) -> None: - download_progress_queue.put_nowait(progress) - - self.shard_downloader.on_progress(download_progress_callback) - download_task = asyncio.create_task( - self.shard_downloader.ensure_shard(op.shard_metadata) - ) - - try: - async for event in self._monitor_download_progress( - assigned_runner, download_progress_queue - ): - yield event - finally: - if not download_task.done(): - download_task.cancel() - - async def _monitor_download_progress( - self, - assigned_runner: AssignedRunner, - download_progress_queue: asyncio.Queue[RepoDownloadProgress], - ) -> AsyncGenerator[Event, None]: - """Monitors download progress and yields status updates.""" - last_progress_time = 0.0 - throttle_interval_secs = 1.0 - - while True: - progress: RepoDownloadProgress = await asyncio.wait_for( - download_progress_queue.get(), timeout=15 - ) - - if progress.status == "complete": - async for ( - event - ) in self._update_runner_status_to_completed_then_inactive( - assigned_runner - ): - yield event - break - elif progress.status == "in_progress": - if time.monotonic() - last_progress_time > throttle_interval_secs: - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadOngoing( - node_id=self.node_id, - 
download_progress=DownloadProgressData( - total_bytes=progress.total_bytes, - downloaded_bytes=progress.downloaded_bytes, - ), - ) - ) - yield assigned_runner.status_update_event() - last_progress_time = time.monotonic() - - async def _execute_assign_op( - self, op: AssignRunnerOp - ) -> AsyncGenerator[Event, None]: - """ - A runner has been assigned. We need to also ensure that it's downloaded. - This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state. - """ - assigned_runner = self._create_assigned_runner(op) - initial_progress = ( - await self.shard_downloader.get_shard_download_status_for_shard( - op.shard_metadata - ) - ) - - if initial_progress.status == "complete": - async for event in self._handle_already_downloaded_shard(assigned_runner): - yield event - else: - async for event in self._handle_shard_download_process( - assigned_runner, op, initial_progress - ): - yield event - - async def _execute_unassign_op( - self, op: UnassignRunnerOp - ) -> AsyncGenerator[Event, None]: - if op.runner_id not in self.assigned_runners: - return - - # We can try to do a graceful shutdown of the runner. - runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner - if runner is not None: - await runner.astop() - - # This is all we really need: - del self.assigned_runners[op.runner_id] - yield RunnerDeleted(runner_id=op.runner_id) - - return - yield - - async def _execute_runner_up_op( - self, op: RunnerUpOp, initialize_timeout: Optional[float] = None - ) -> AsyncGenerator[Event, None]: - assigned_runner = self.assigned_runners[op.runner_id] - - # Emit "Starting" status right away so UI can show loading state - assigned_runner.status = StartingRunnerStatus() - yield assigned_runner.status_update_event() - - assigned_runner.runner = await RunnerSupervisor.create( - model_shard_meta=assigned_runner.shard_metadata, - hosts=assigned_runner.hosts, - initialize_timeout=initialize_timeout, - ) - - if assigned_runner.runner.runner_process.is_alive(): - assigned_runner.status = LoadedRunnerStatus() - else: - runner = assigned_runner.runner - logger.warning(f"Runner status is not runner_process.is_alive(): exit code {runner.runner_process.exitcode}") - - assigned_runner.status = FailedRunnerStatus() - yield self.assigned_runners[op.runner_id].status_update_event() - - async def _execute_runner_down_op( - self, op: RunnerDownOp - ) -> AsyncGenerator[Event, None]: - assigned_runner = self.assigned_runners[op.runner_id] - - if isinstance(assigned_runner.runner, RunnerSupervisor): - await assigned_runner.runner.astop() - - assigned_runner.runner = None - - assigned_runner.status = InactiveRunnerStatus() - yield assigned_runner.status_update_event() - return - - async def _execute_runner_failed_op( - self, op: RunnerFailedOp - ) -> AsyncGenerator[Event, None]: - """ - We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down. - """ - assigned_runner = self.assigned_runners[op.runner_id] - - if isinstance(assigned_runner.runner, RunnerSupervisor): - await ( - assigned_runner.runner.astop() - ) # astop the runner to ensure it clears out of memory. - - assigned_runner.status = FailedRunnerStatus() - yield self.assigned_runners[op.runner_id].status_update_event() - - async def _execute_task_op(self, op: ExecuteTaskOp) -> AsyncGenerator[Event, None]: - """ - This is the entry point for a chat completion starting. 
-
-    async def _execute_task_op(self, op: ExecuteTaskOp) -> AsyncGenerator[Event, None]:
-        """
-        This is the entry point for starting a chat completion.
-
-        While there is only one execute function, it will get called in different ways for runner 0 and runners [1, 2, 3, ...].
-        Runners [1, 2, 3, ...] will run this method when a task is in the 'pending' state.
-        Runner 0 will run this method when a task is in the 'running' state.
-        TODO: How do we handle the logic of ensuring that n-1 nodes have started their execution before allowing the 0th runner to start?
-        This is still a little unclear to me.
-        """
-        assigned_runner = self.assigned_runners[op.runner_id]
-
-        async def inner_execute(queue: asyncio.Queue[Event]) -> None:
-            async def running_callback(queue: asyncio.Queue[Event]) -> None:
-                # Called when the MLX process has been kicked off
-                assigned_runner.status = RunningRunnerStatus()
-                await queue.put(assigned_runner.status_update_event())
-
-                if assigned_runner.shard_metadata.device_rank == 0:
-                    await queue.put(
-                        TaskStateUpdated(
-                            task_id=op.task.task_id,
-                            task_status=TaskStatus.RUNNING,
-                        )
-                    )
-
-            assert assigned_runner.runner is not None
-            assert assigned_runner.runner.runner_process.is_alive()
-
-            async for chunk in assigned_runner.runner.stream_response(
-                task=op.task, request_started_callback=partial(running_callback, queue)
-            ):
-                if assigned_runner.shard_metadata.device_rank == 0:
-                    await queue.put(
-                        ChunkGenerated(
-                            # TODO: at some point we will no longer have a bijection between task_id and row_id,
-                            # so we probably want to store a mapping between the two in our Worker object.
-                            command_id=chunk.command_id,
-                            chunk=chunk,
-                        )
-                    )
-
-            if assigned_runner.shard_metadata.device_rank == 0:
-                await queue.put(
-                    TaskStateUpdated(
-                        task_id=op.task.task_id,
-                        task_status=TaskStatus.COMPLETE,
-                    )
-                )
-
-            # After a successful inference:
-            assigned_runner.status = LoadedRunnerStatus()
-            await queue.put(assigned_runner.status_update_event())
-
-        queue: asyncio.Queue[Event] = asyncio.Queue()
-        task = asyncio.create_task(inner_execute(queue))
-
-        # TODO: The initial (prefill) timeout can be made dynamic
-        # model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
-
-        try:
-            # Yield items from the queue
-            while True:
-                if task.done() and (exception := task.exception()):
-                    raise exception
-
-                try:
-                    # Use a timeout to periodically check task status
-                    item: Event = await asyncio.wait_for(queue.get(), timeout=0.01)
-                except asyncio.TimeoutError:
-                    continue
-
-                yield item
-                if isinstance(item, RunnerStatusUpdated) and isinstance(
-                    item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
-                ):
-                    if isinstance(item.runner_status, LoadedRunnerStatus):
-                        assigned_runner.failures = []
-
-                    break
-        finally:
-            # Ensure the task is cleaned up
-            try:
-                await asyncio.wait_for(task, timeout=5)
-            except asyncio.TimeoutError:
-                logger.warning(
-                    "Timed out waiting for task cleanup after inference execution."
-                )
-
-    ## Operation Planner
-
-    async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
-        ## It would be great if we could get rid of this async for ... yield pattern.
-        match op.op_type:
-            case RunnerOpType.ASSIGN_RUNNER:
-                event_generator = self._execute_assign_op(op)
-            case RunnerOpType.UNASSIGN_RUNNER:
-                event_generator = self._execute_unassign_op(op)
-            case RunnerOpType.RUNNER_UP:
-                event_generator = self._execute_runner_up_op(op)
-            case RunnerOpType.RUNNER_DOWN:
-                event_generator = self._execute_runner_down_op(op)
-            case RunnerOpType.RUNNER_FAILED:
-                event_generator = self._execute_runner_failed_op(op)
-            case RunnerOpType.CHAT_COMPLETION:
-                event_generator = self._execute_task_op(op)
-
-        async for event in event_generator:
-            yield event
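execute_op is the single dispatch point for runner ops, and callers are expected to forward every event it yields. A hedged sketch of such a driver; the run_op wrapper and its error handling are assumptions, not code from this diff:

# Hypothetical driver loop showing how execute_op, event_publisher, and
# fail_runner (defined below) could fit together.
async def run_op(worker, op) -> None:
    try:
        async for event in worker.execute_op(op):
            await worker.event_publisher(event)
    except Exception as e:
        async for event in worker.fail_runner(e, op.runner_id):
            await worker.event_publisher(event)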
-
-    async def fail_runner(
-        self, e: Exception, runner_id: RunnerId
-    ) -> AsyncGenerator[Event, None]:
-        if runner_id in self.assigned_runners:
-            assigned_runner = self.assigned_runners[runner_id]
-
-            if assigned_runner.runner is not None:
-                await assigned_runner.runner.astop()
-                assigned_runner.runner = None
-
-            assigned_runner.status = FailedRunnerStatus(error_message=str(e))
-            assigned_runner.failures.append((time.time(), e))
-
-            # The failure count is reset after each successful inference,
-            # so three entries means three consecutive failures.
-            if len(assigned_runner.failures) >= 3:
-                # Too many retries; emit an InstanceDeleted event.
-                yield InstanceDeleted(instance_id=assigned_runner.instance_id)
-
-            yield assigned_runner.status_update_event()
-
-    async def fail_task(
-        self, e: Exception, runner_id: RunnerId, task_id: TaskId
-    ) -> AsyncGenerator[Event, None]:
-        if runner_id in self.assigned_runners:
-            yield TaskStateUpdated(
-                task_id=task_id,
-                task_status=TaskStatus.FAILED,
-            )
-
-            yield TaskFailed(
-                task_id=task_id, error_type=str(type(e)), error_message=str(e)
-            )
-
-            async for event in self.fail_runner(e, runner_id):
-                yield event
-
-    async def event_publisher(self, event: Event) -> None:
-        assert self.worker_events is not None
-        await self.worker_events.append_events([event], self.node_id)
-        logger.info(f"published event: {event}")
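A TODO above notes that task_id and row_id will eventually stop being a bijection; until then, the bidict package added to uv.lock below is a natural fit for that mapping. A minimal sketch, assuming the Worker would own such a structure (the variable name and keys are hypothetical):

# Hypothetical: a bidict keeps task_id <-> row_id lookups O(1) in both
# directions and raises if either side would be duplicated.
from bidict import bidict

task_to_row: bidict[str, int] = bidict()
task_to_row["task-123"] = 42                  # task_id -> row_id
assert task_to_row.inverse[42] == "task-123"  # row_id -> task_id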
"sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "bidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "cobs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "cryptography", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -293,9 +306,12 @@ requires-dist = [ { name = "aiofiles", specifier = ">=24.1.0" }, { name = "aiohttp", specifier = ">=3.12.14" }, { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "anyio", specifier = ">=4.10.0" }, { name = "base58", specifier = ">=2.1.1" }, + { name = "bidict", specifier = ">=0.23.1" }, { name = "cobs", specifier = ">=1.2.2" }, { name = "cryptography", specifier = ">=45.0.5" }, + { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" }, { name = "fastapi", specifier = ">=0.116.1" }, { name = "filelock", specifier = ">=3.18.0" }, { name = "greenlet", specifier = ">=3.2.4" }, @@ -329,6 +345,27 @@ dev = [ { name = "ruff", specifier = ">=0.11.13" }, ] +[[package]] +name = "exo-pyo3-bindings" +version = "0.1.0" +source = { editable = "rust/exo_pyo3_bindings" } + +[package.dev-dependencies] +dev = [ + { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + +[package.metadata] + +[package.metadata.requires-dev] +dev = [ + { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" }, + { name = "pytest", specifier = ">=8.4.0" }, + { name = "pytest-asyncio", specifier = ">=1.0.0" }, +] + [[package]] name = "exo-scripts" version = "0.1.0"