From 0e32599e71628175d207d1e3257a72b98e84cad2 Mon Sep 17 00:00:00 2001 From: Gelu Vrabie Date: Thu, 31 Jul 2025 20:36:47 +0100 Subject: [PATCH] fix libp2p + other prs that were wrongly overwritten before (111,112,117,118,1119 + misc commits from Alex) Co-authored-by: Gelu Vrabie Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com> Co-authored-by: Seth Howes <71157822+sethhowes@users.noreply.github.com> Co-authored-by: Matt Beton Co-authored-by: Alex Cheema --- .github/actions/format/action.yml | 2 +- .github/actions/lint-check/action.yml | 2 +- .github/actions/lint/action.yml | 2 +- .../actions/regenerate-protobufs/action.yml | 2 +- .github/actions/typecheck/action.yml | 6 +- .github/actions/unit-test/action.yml | 12 + .github/workflows/e2e_test.yml | 360 +++++++ .github/workflows/pipeline.yml | 111 ++- engines/mlx/utils_mlx.py | 4 + justfile | 3 + master/discovery_supervisor.py | 7 +- master/forwarder_supervisor.py | 1 + master/main.py | 21 +- master/tests/test_topology.py | 17 +- networking/forwarder/go.mod | 98 +- networking/forwarder/go.sum | 253 ++--- networking/forwarder/main.go | 7 + networking/forwarder/src/event_writer.go | 259 +++++ networking/forwarder/src/libp2p.go | 445 ++++++++- pyproject.toml | 6 +- run.sh | 4 +- rust/discovery/src/behaviour.rs | 2 +- rust/discovery/src/lib.rs | 18 +- rust/discovery/src/transport.rs | 3 +- rust/exo_pyo3_bindings/src/discovery.rs | 124 +++ shared/apply/apply.py | 33 +- shared/db/sqlite/connector.py | 138 ++- shared/db/sqlite/event_log_manager.py | 68 +- shared/topology.py | 35 +- shared/types/events/_events.py | 19 + shared/types/worker/common.py | 1 - shared/types/worker/ops.py | 9 - shared/types/worker/runners.py | 22 +- worker/common.py | 35 + worker/download/conftest.py | 4 +- worker/main.py | 648 +------------ worker/plan.py | 205 ++++ worker/runner/communication.py | 2 +- worker/tests/conftest.py | 201 ++-- worker/tests/constants.py | 26 + worker/tests/test_download.py | 1 + worker/tests/test_handlers/conftest.py | 70 ++ .../test_handlers/test_handlers_happy.py | 159 +++ .../tests/test_handlers/test_handlers_sad.py | 61 ++ worker/tests/test_handlers/utils.py | 18 + worker/tests/test_integration/conftest.py | 36 + .../integration_utils.py} | 0 .../test_creation.py} | 304 +----- .../tests/test_integration/test_inference.py | 256 +++++ .../test_supervisor_errors.py | 124 ++- worker/tests/test_plan/test_worker_plan.py | 540 +++++++++++ .../tests/test_plan/test_worker_plan_utils.py | 272 ++++++ worker/tests/test_runner_connection.py | 73 +- worker/tests/test_serdes.py | 2 - worker/tests/test_spinup_timeout.py | 8 +- .../{ => test_supervisor}/test_supervisor.py | 6 - worker/tests/test_worker_handlers.py | 237 ----- worker/tests/test_worker_plan.py | 913 ------------------ worker/tests/test_worker_plan_utils.py | 195 ---- worker/worker.py | 415 ++++++++ 60 files changed, 4048 insertions(+), 2857 deletions(-) create mode 100644 .github/actions/unit-test/action.yml create mode 100644 .github/workflows/e2e_test.yml create mode 100644 networking/forwarder/src/event_writer.go create mode 100644 worker/common.py create mode 100644 worker/plan.py create mode 100644 worker/tests/constants.py create mode 100644 worker/tests/test_handlers/conftest.py create mode 100644 worker/tests/test_handlers/test_handlers_happy.py create mode 100644 worker/tests/test_handlers/test_handlers_sad.py create mode 100644 worker/tests/test_handlers/utils.py create mode 100644 worker/tests/test_integration/conftest.py rename 
worker/tests/{test_worker_integration_utils.py => test_integration/integration_utils.py} (100%) rename worker/tests/{test_worker_integration.py => test_integration/test_creation.py} (53%) create mode 100644 worker/tests/test_integration/test_inference.py rename worker/tests/{ => test_integration}/test_supervisor_errors.py (65%) create mode 100644 worker/tests/test_plan/test_worker_plan.py create mode 100644 worker/tests/test_plan/test_worker_plan_utils.py rename worker/tests/{ => test_supervisor}/test_supervisor.py (98%) delete mode 100644 worker/tests/test_worker_handlers.py delete mode 100644 worker/tests/test_worker_plan.py delete mode 100644 worker/tests/test_worker_plan_utils.py create mode 100644 worker/worker.py diff --git a/.github/actions/format/action.yml b/.github/actions/format/action.yml index 1b43e9c4..5df1b5f4 100644 --- a/.github/actions/format/action.yml +++ b/.github/actions/format/action.yml @@ -6,5 +6,5 @@ runs: using: "composite" steps: - name: Format code - run: nix develop -c just fmt + run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just fmt shell: bash diff --git a/.github/actions/lint-check/action.yml b/.github/actions/lint-check/action.yml index f666cae9..7d69c90d 100644 --- a/.github/actions/lint-check/action.yml +++ b/.github/actions/lint-check/action.yml @@ -6,5 +6,5 @@ runs: using: "composite" steps: - name: Lint check - run: nix develop -c just lint-check + run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint-check shell: bash diff --git a/.github/actions/lint/action.yml b/.github/actions/lint/action.yml index 68c7eb53..05f7939c 100644 --- a/.github/actions/lint/action.yml +++ b/.github/actions/lint/action.yml @@ -6,5 +6,5 @@ runs: using: "composite" steps: - name: Lint code - run: nix develop -c just lint + run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint shell: bash diff --git a/.github/actions/regenerate-protobufs/action.yml b/.github/actions/regenerate-protobufs/action.yml index dfc65512..6da2a7a4 100644 --- a/.github/actions/regenerate-protobufs/action.yml +++ b/.github/actions/regenerate-protobufs/action.yml @@ -6,5 +6,5 @@ runs: using: "composite" steps: - name: Regenerate protobufs - run: nix develop -c just regenerate-protobufs + run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just regenerate-protobufs shell: bash diff --git a/.github/actions/typecheck/action.yml b/.github/actions/typecheck/action.yml index ba61737f..cd52d6e3 100644 --- a/.github/actions/typecheck/action.yml +++ b/.github/actions/typecheck/action.yml @@ -1,12 +1,12 @@ name: Type Check -description: "Run static type checker" +description: "Run type checker" runs: using: "composite" steps: - name: Run type checker run: | - nix develop -c just sync - nix develop -c just check + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check shell: bash diff --git a/.github/actions/unit-test/action.yml b/.github/actions/unit-test/action.yml new file mode 100644 index 00000000..65f5e07b --- /dev/null +++ b/.github/actions/unit-test/action.yml @@ -0,0 +1,12 @@ +name: Unit Test + +description: "Run unit tests" + +runs: + using: "composite" + steps: + - name: Run unit tests + run: | + nix --extra-experimental-features nix-command 
--extra-experimental-features flakes develop -c just sync-clean + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just test-fast + shell: bash diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml new file mode 100644 index 00000000..9b512e0e --- /dev/null +++ b/.github/workflows/e2e_test.yml @@ -0,0 +1,360 @@ +name: E2E Test + +on: + workflow_dispatch: # This allows manual triggering + # push: + # branches: [ '*' ] + # tags: [ '*' ] + +jobs: + master: + runs-on: ['self-hosted', 'macOS'] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Configure git user + run: | + git config --local user.email "github-actions@users.noreply.github.com" + git config --local user.name "github-actions bot" + shell: bash + + - name: Pull LFS files + run: | + echo "Pulling Git LFS files..." + git lfs pull + shell: bash + + - name: Reset databases + run: | + if [ -d ~/.exo ]; then + rm -rf ~/.exo/*.db* + fi + + - name: Setup EXO_HOME and API_PORT + run: | + EXO_HOME=$(mktemp -d -t exo-e2e-master-XXXXXXXX) + # Generate random port (macOS compatible method) + API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1))) + echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV + echo "API_PORT=$API_PORT" >> $GITHUB_ENV + echo "Created EXO_HOME: $EXO_HOME" + echo "Generated API_PORT: $API_PORT" + echo "Verifying API_PORT is set: $API_PORT" + shell: bash + + - name: Setup Nix Environment + run: | + echo "Checking for nix installation..." + + # Check if nix binary exists directly + if [ -f /nix/var/nix/profiles/default/bin/nix ]; then + echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix" + export PATH="/nix/var/nix/profiles/default/bin:$PATH" + echo "PATH=$PATH" >> $GITHUB_ENV + nix --version + elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then + echo "Found nix profile script, sourcing..." + source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh + nix --version + elif command -v nix >/dev/null 2>&1; then + echo "Nix already in PATH" + nix --version + else + echo "Nix not found. Debugging info:" + echo "Contents of /nix/var/nix/profiles/default/:" + ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found" + echo "Contents of /nix/var/nix/profiles/default/bin/:" + ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found" + exit 1 + fi + shell: bash + + - name: Print macOS system information + run: | + echo "=== macOS System Information ===" + echo "OS Version:" + sw_vers + + echo -e "\n=== Memory Information ===" + system_profiler SPMemoryDataType + + echo -e "\n=== Memory Usage Summary ===" + vm_stat | perl -ne '/page size of (\d+)/ and $size=$1; /Pages free: (\d+)/ and printf "Free Memory: %.2f GB\n", $1 * $size / 1024 / 1024 / 1024' + top -l 1 -s 0 | grep PhysMem + + echo -e "\n=== CPU Information ===" + sysctl -n machdep.cpu.brand_string + system_profiler SPHardwareDataType | grep -E "Cores|Processors" + + echo -e "\n=== Disk Space ===" + df -h / + + # - name: Setup Hugging Face token + # run: | + # mkdir -p ~/.cache/huggingface + # echo "${{ secrets.HF_TOKEN }}" > ~/.cache/huggingface/token + + - name: Sync dependencies + run: | + echo "Running just sync-clean to ensure clean dependencies..." + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync-clean + shell: bash + + - name: Build forwarder + run: | + echo "Building Go forwarder binary..."
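+ # The repeated --extra-experimental-features flags enable the experimental nix CLI and flakes for this one invocation only, so the self-hosted runner needs no global nix.conf change.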
+ nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder + shell: bash + + - name: Start node (master) + run: | + echo "Starting master node with debug enabled..." + echo "Environment check - API_PORT: '$API_PORT'" + echo "Environment check - EXO_HOME: '$EXO_HOME'" + if [ -z "$API_PORT" ]; then + echo "ERROR: API_PORT is not set!" + exit 1 + fi + # Run with Python unbuffered output and maximum debug level + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME API_PORT=$API_PORT PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run master/main.py" > /tmp/master_node.log 2>&1 & + MASTER_PID=$! + echo "Started master node in background with PID: $MASTER_PID" + echo "Log file: /tmp/master_node.log" + + echo "Starting worker node..." + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run worker/main.py" > /tmp/worker_node.log 2>&1 & + WORKER_PID=$! + echo "Started worker node in background with PID: $WORKER_PID" + echo "Log file: /tmp/worker_node.log" + + for i in {1..30}; do + echo "Attempt $i: Checking if master node is ready..." + if curl -s http://localhost:$API_PORT/state > /dev/null 2>&1; then + echo "Master node is ready!" + break + fi + if [ $i -eq 30 ]; then + echo "Master node failed to start within 30 seconds. Checking logs..." + echo "=== Master node log ===" + cat /tmp/master_node.log || echo "No master log file found" + echo "=== Worker node log ===" + cat /tmp/worker_node.log || echo "No worker log file found" + exit 1 + fi + sleep 1 + done + + # wait for master to have a COMPLETE or FAILED task in the state + for i in {1..30}; do + if [ "$(curl -s http://localhost:$API_PORT/state | jq -r '.tasks | any(.task_status == "COMPLETE" or .task_status == "FAILED")')" = "true" ]; then + echo "Master node has a COMPLETE or FAILED task in the state" + break + fi + sleep 1 + done + + echo "=== Master node log ===" + cat /tmp/master_node.log || echo "No master log file found" + echo "=== Worker node log ===" + cat /tmp/worker_node.log || echo "No worker log file found" + + - name: Cleanup EXO_HOME + run: | + echo "Cleaning up EXO_HOME: $EXO_HOME" + rm -rf "$EXO_HOME" + shell: bash + if: always() + + worker: + runs-on: ['self-hosted', 'macOS'] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Configure git user + run: | + git config --local user.email "github-actions@users.noreply.github.com" + git config --local user.name "github-actions bot" + shell: bash + + - name: Pull LFS files + run: | + echo "Pulling Git LFS files..." + git lfs pull + shell: bash + + - name: Reset databases + run: | + if [ -d ~/.exo ]; then + rm -rf ~/.exo/*.db* + fi + + - name: Setup EXO_HOME and API_PORT + run: | + EXO_HOME=$(mktemp -d -t exo-e2e-worker-XXXXXXXX) + # Generate random port (macOS compatible method) + API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1))) + echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV + echo "API_PORT=$API_PORT" >> $GITHUB_ENV + echo "Created EXO_HOME: $EXO_HOME" + echo "Generated API_PORT: $API_PORT" + echo "Verifying API_PORT is set: $API_PORT" + shell: bash + + - name: Setup Nix Environment + run: | + echo "Checking for nix installation..."
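+ # Intent of the checks below: prefer the default profile's nix binary, else source the nix-daemon profile script, else accept any nix already on PATH, and only then fail with directory listings for debugging.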
+ + # Check if nix binary exists directly + if [ -f /nix/var/nix/profiles/default/bin/nix ]; then + echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix" + export PATH="/nix/var/nix/profiles/default/bin:$PATH" + echo "PATH=$PATH" >> $GITHUB_ENV + nix --version + elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then + echo "Found nix profile script, sourcing..." + source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh + nix --version + elif command -v nix >/dev/null 2>&1; then + echo "Nix already in PATH" + nix --version + else + echo "Nix not found. Debugging info:" + echo "Contents of /nix/var/nix/profiles/default/:" + ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found" + echo "Contents of /nix/var/nix/profiles/default/bin/:" + ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found" + exit 1 + fi + shell: bash + + - name: Print macOS system information + run: | + echo "=== macOS System Information ===" + echo "OS Version:" + sw_vers + + echo -e "\n=== Memory Information ===" + system_profiler SPMemoryDataType + + echo -e "\n=== Memory Usage Summary ===" + vm_stat | perl -ne '/page size of (\d+)/ and $size=$1; /Pages free: (\d+)/ and printf "Free Memory: %.2f GB\n", $1 * $size / 1024 / 1024 / 1024' + top -l 1 -s 0 | grep PhysMem + + echo -e "\n=== CPU Information ===" + sysctl -n machdep.cpu.brand_string + system_profiler SPHardwareDataType | grep -E "Cores|Processors" + + echo -e "\n=== Disk Space ===" + df -h / + + # - name: Setup Hugging Face token + # run: | + # mkdir -p ~/.cache/huggingface + # echo "${{ secrets.HF_TOKEN }}" > ~/.cache/huggingface/token + + - name: Sync dependencies + run: | + echo "Running just sync-clean to ensure clean dependencies..." + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync-clean + shell: bash + + - name: Build forwarder + run: | + echo "Building Go forwarder binary..." + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder + shell: bash + + - name: Start node (replica) + run: | + echo "Starting master node with debug enabled..." + echo "Environment check - API_PORT: '$API_PORT'" + echo "Environment check - EXO_HOME: '$EXO_HOME'" + if [ -z "$API_PORT" ]; then + echo "ERROR: API_PORT is not set!" + exit 1 + fi + # Run with Python unbuffered output and maximum debug level + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_RUN_AS_REPLICA=1 EXO_HOME=$EXO_HOME API_PORT=$API_PORT PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run master/main.py" > /tmp/master_node.log 2>&1 & + MASTER_PID=$! + echo "Started master node in background with PID: $MASTER_PID" + echo "Log file: /tmp/master_node.log" + + echo "Starting worker node..." + nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run worker/main.py" > /tmp/worker_node.log 2>&1 & + WORKER_PID=$! + echo "Started worker node in background with PID: $WORKER_PID" + echo "Log file: /tmp/worker_node.log" + + echo "Waiting for master node to start on port $API_PORT..." + # Wait for the master node to be ready (up to 30 seconds) + for i in {1..30}; do + echo "Attempt $i: Checking if master node is ready..." 
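+ # Readiness here only means /state answered; the response body is discarded. + # A terser sketch of the same poll, assuming curl 7.52+ (not used here): + # curl -fsS --retry 30 --retry-delay 1 --retry-connrefused "http://localhost:$API_PORT/state" > /dev/null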
+ if curl -s http://localhost:$API_PORT/state > /dev/null 2>&1; then + echo "Master node is ready!" + break + fi + if [ $i -eq 30 ]; then + echo "Master node failed to start within 30 seconds. Checking logs..." + echo "=== Master node log ===" + cat /tmp/master_node.log || echo "No master log file found" + echo "=== Worker node log ===" + cat /tmp/worker_node.log || echo "No worker log file found" + exit 1 + fi + sleep 1 + done + + resp=$(curl -X POST http://localhost:$API_PORT/instance -H "Content-Type: application/json" -d '{"model_id": "llama-3.2:1b"}') + echo "Response: $resp" + instance_id=$(echo $resp | jq -r '.instance_id') + echo "Instance ID: $instance_id" + + for i in {1..50}; do + resp=$(curl -s -w "%{http_code}" -X GET http://localhost:$API_PORT/instance/$instance_id -H "Content-Type: application/json") + http_code="${resp: -3}" + response_body="${resp%???}" + echo "HTTP Code: $http_code" + echo "Response: $response_body" + + if [ "$http_code" == "200" ]; then + instance_status=$(echo $response_body | jq -r '.instance_type') + if [ "$instance_status" == "ACTIVE" ]; then + echo "Instance is ready" + break + fi + elif [ "$http_code" == "404" ]; then + echo "Instance not yet created, waiting..." + else + echo "Unexpected HTTP status: $http_code" + fi + sleep 1 + done + + resp=$(curl http://localhost:$API_PORT/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "llama-3.2:1b", "messages": [{"role": "user", "content": "What is the meaning of exo?"}], "temperature": 0.7}') + echo "Response: $resp" + + resp=$(curl -X DELETE http://localhost:$API_PORT/instance/$instance_id -H "Content-Type: application/json") + echo "Response: $resp" + + echo "=== Master node log ===" + cat /tmp/master_node.log || echo "No master log file found" + echo "=== Worker node log ===" + cat /tmp/worker_node.log || echo "No worker log file found" + + kill $MASTER_PID + kill $WORKER_PID + + - name: Cleanup EXO_HOME + run: | + echo "Cleaning up EXO_HOME: $EXO_HOME" + rm -rf "$EXO_HOME" + shell: bash + if: always() diff --git a/.github/workflows/pipeline.yml b/.github/workflows/pipeline.yml index e2834848..71ba82f8 100644 --- a/.github/workflows/pipeline.yml +++ b/.github/workflows/pipeline.yml @@ -12,9 +12,12 @@ on: jobs: typecheck: - runs-on: ubuntu-22.04 + runs-on: ['self-hosted', 'macOS'] steps: - - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true - name: Configure git user run: | @@ -22,23 +25,54 @@ jobs: git config --local user.name "github-actions bot" shell: bash - - uses: cachix/install-nix-action@v31 - with: - github_access_token: ${{ secrets.GITHUB_TOKEN }} + - name: Pull LFS files + run: | + echo "Pulling Git LFS files..." + git lfs pull + shell: bash + + - name: Setup Nix Environment + run: | + echo "Checking for nix installation..." + + # Check if nix binary exists directly + if [ -f /nix/var/nix/profiles/default/bin/nix ]; then + echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix" + export PATH="/nix/var/nix/profiles/default/bin:$PATH" + echo "PATH=$PATH" >> $GITHUB_ENV + nix --version + elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then + echo "Found nix profile script, sourcing..." + source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh + nix --version + elif command -v nix >/dev/null 2>&1; then + echo "Nix already in PATH" + nix --version + else + echo "Nix not found. 
Debugging info:" + echo "Contents of /nix/var/nix/profiles/default/:" + ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found" + echo "Contents of /nix/var/nix/profiles/default/bin/:" + ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found" + exit 1 + fi + shell: bash - uses: ./.github/actions/typecheck ci: needs: typecheck - runs-on: ubuntu-22.04 + runs-on: ['self-hosted', 'macOS'] permissions: contents: read env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v4 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} + lfs: true - name: Configure git user run: | @@ -46,12 +80,67 @@ jobs: git config --local user.name "github-actions bot" shell: bash - - uses: cachix/install-nix-action@v31 - with: - github_access_token: ${{ secrets.GITHUB_TOKEN }} + - name: Pull LFS files + run: | + echo "Pulling Git LFS files..." + git lfs pull + shell: bash + + - name: Setup EXO_HOME and API_PORT + run: | + EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX) + # Generate random port (macOS compatible method) + API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1))) + echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV + echo "API_PORT=$API_PORT" >> $GITHUB_ENV + echo "Created EXO_HOME: $EXO_HOME" + echo "Generated API_PORT: $API_PORT" + shell: bash + + - name: Setup Nix Environment + run: | + echo "Checking for nix installation..." + + # Check if nix binary exists directly + if [ -f /nix/var/nix/profiles/default/bin/nix ]; then + echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix" + export PATH="/nix/var/nix/profiles/default/bin:$PATH" + echo "PATH=$PATH" >> $GITHUB_ENV + nix --version + elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then + echo "Found nix profile script, sourcing..." + source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh + nix --version + elif command -v nix >/dev/null 2>&1; then + echo "Nix already in PATH" + nix --version + else + echo "Nix not found. Debugging info:" + echo "Contents of /nix/var/nix/profiles/default/:" + ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found" + echo "Contents of /nix/var/nix/profiles/default/bin/:" + ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found" + exit 1 + fi + shell: bash + + - name: Build forwarder + run: | + echo "Building Go forwarder binary..." 
+ nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder + shell: bash - uses: ./.github/actions/verify-clean with: step: regenerate-protobufs - - uses: ./.github/actions/lint-check \ No newline at end of file + - uses: ./.github/actions/lint-check + + - uses: ./.github/actions/unit-test + + - name: Cleanup EXO_HOME + run: | + echo "Cleaning up EXO_HOME: $EXO_HOME" + rm -rf "$EXO_HOME" + shell: bash + if: always() diff --git a/engines/mlx/utils_mlx.py b/engines/mlx/utils_mlx.py index 1b77413f..1dde2e14 100644 --- a/engines/mlx/utils_mlx.py +++ b/engines/mlx/utils_mlx.py @@ -1,6 +1,7 @@ import asyncio import concurrent.futures import os +import resource from asyncio import AbstractEventLoop from typing import Any, Callable @@ -18,6 +19,8 @@ from shared.types.worker.shards import ShardMetadata from worker.download.download_utils import build_model_path from worker.runner.communication import runner_print +# Needed for 8 bit model +resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) def mx_barrier(): mx.eval( # type: ignore @@ -86,6 +89,7 @@ def shard_and_load(model_shard_meta: ShardMetadata) -> tuple[nn.Module, Tokenize tokenizer = load_tokenizer(model_path) assert isinstance(tokenizer, TokenizerWrapper) model = auto_parallel(model, model_shard_meta) + mx.eval(model.parameters()) # type: ignore # Synchronize processes before generation to avoid timeout mx_barrier() diff --git a/justfile b/justfile index 5b92d3c4..871eec6d 100644 --- a/justfile +++ b/justfile @@ -19,6 +19,9 @@ lint-check: test: uv run pytest master worker shared engines/* +test-fast: + uv run pytest master shared engines/* + check: uv run basedpyright --project pyproject.toml diff --git a/master/discovery_supervisor.py b/master/discovery_supervisor.py index 440d512b..08f2c072 100644 --- a/master/discovery_supervisor.py +++ b/master/discovery_supervisor.py @@ -48,7 +48,10 @@ class DiscoverySupervisor: local_multiaddr = Multiaddr(address=str(e.local_addr)) send_back_multiaddr = Multiaddr(address=str(e.send_back_addr)) connection_profile = None - + + if send_back_multiaddr.ipv4_address == local_multiaddr.ipv4_address: + return + topology_edge_created = TopologyEdgeCreated(edge=Connection( local_node_id=local_node_id, send_back_node_id=send_back_node_id, @@ -56,7 +59,7 @@ class DiscoverySupervisor: send_back_multiaddr=send_back_multiaddr, connection_profile=connection_profile )) - self.logger.error( + self.logger.info( msg=f"CONNECTED CALLBACK: {local_node_id} -> {send_back_node_id}, {local_multiaddr} -> {send_back_multiaddr}") await self.global_events.append_events( [topology_edge_created], diff --git a/master/forwarder_supervisor.py b/master/forwarder_supervisor.py index 979d362e..4e7fa918 100644 --- a/master/forwarder_supervisor.py +++ b/master/forwarder_supervisor.py @@ -111,6 +111,7 @@ class ForwarderSupervisor: env_vars["FORWARDER_NODE_ID"] = str(self.node_id) self._process = await asyncio.create_subprocess_exec( str(self._binary_path), + "--events-db", str(EXO_WORKER_EVENT_DB), f'{pairs}', stdout=None, stderr=None, diff --git a/master/main.py b/master/main.py index 2ce5ed8b..0b991e96 100644 --- a/master/main.py +++ b/master/main.py @@ -9,7 +9,8 @@ from typing import List from exo_pyo3_bindings import Keypair from master.api import start_fastapi_server -from master.discovery_supervisor import DiscoverySupervisor + +# from master.discovery_supervisor import DiscoverySupervisor from master.election_callback import ElectionCallbacks from 
master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor from master.placement import get_instance_placements, get_transition_events @@ -45,13 +46,13 @@ class Master: self.command_buffer = command_buffer self.global_events = global_events self.worker_events = worker_events - self.discovery_supervisor = DiscoverySupervisor( - node_id_keypair, - node_id, - # TODO: needs to be more general for when we have master election - worker_events if os.getenv('EXO_RUN_AS_REPLICA') in set(['TRUE', 'true', '1']) else global_events, - logger - ) + # self.discovery_supervisor = DiscoverySupervisor( + # node_id_keypair, + # node_id, + # # TODO: needs to be more general for when we have master election + # worker_events if os.getenv('EXO_RUN_AS_REPLICA') in set(['TRUE', 'true', '1']) else global_events, + # logger + # ) self.forwarder_supervisor = ForwarderSupervisor( self.node_id, forwarder_binary_path=forwarder_binary_path, @@ -116,7 +117,7 @@ class Master: await self.event_log_for_writes.append_events(next_events, origin=self.node_id) # 2. get latest events - events = await self.event_log_for_reads.get_events_since(self.state.last_event_applied_idx) + events = await self.event_log_for_reads.get_events_since(self.state.last_event_applied_idx, ignore_no_op_events=True) if len(events) == 0: await asyncio.sleep(0.01) return @@ -157,7 +158,7 @@ class Master: async def main(): logger = logging.getLogger('master_logger') - logger.setLevel(logging.DEBUG) + logger.setLevel(logging.INFO) if not logger.handlers: handler = logging.StreamHandler() handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) diff --git a/master/tests/test_topology.py b/master/tests/test_topology.py index 151ef0c3..9172adbb 100644 --- a/master/tests/test_topology.py +++ b/master/tests/test_topology.py @@ -114,8 +114,7 @@ def test_remove_connection_still_connected(topology: Topology, node_profile: Nod topology.remove_connection(connection) # assert - with pytest.raises(IndexError): - topology.get_connection_profile(connection) + assert topology.get_connection_profile(connection) is None def test_remove_connection_bridge(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection): @@ -129,7 +128,9 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma topology.add_node(Node(node_id=master_id, node_profile=node_profile)) topology.add_node(Node(node_id=node_a_id, node_profile=node_profile)) topology.add_node(Node(node_id=node_b_id, node_profile=node_profile)) - + + topology.set_master_node_id(master_id) + connection_master_to_a = Connection( local_node_id=master_id, send_back_node_id=node_a_id, @@ -157,11 +158,8 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma assert len(remaining_nodes) == 1 assert remaining_nodes[0].node_id == master_id - with pytest.raises(KeyError): - topology.get_node_profile(node_a_id) - - with pytest.raises(KeyError): - topology.get_node_profile(node_b_id) + assert topology.get_node_profile(node_a_id) is None + assert topology.get_node_profile(node_b_id) is None def test_remove_node_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection): @@ -174,8 +172,7 @@ def test_remove_node_still_connected(topology: Topology, node_profile: NodePerfo topology.remove_node(connection.local_node_id) # assert - with pytest.raises(KeyError): - topology.get_node_profile(connection.local_node_id) + assert topology.get_node_profile(connection.local_node_id) is None def 
test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection): diff --git a/networking/forwarder/go.mod b/networking/forwarder/go.mod index b7100a6a..8c3a2aae 100644 --- a/networking/forwarder/go.mod +++ b/networking/forwarder/go.mod @@ -1,16 +1,17 @@ module forwarder -go 1.23 +go 1.23.8 toolchain go1.24.3 replace forwarder/src => ./src require ( - github.com/google/uuid v1.6.0 - github.com/libp2p/go-libp2p v0.39.1 + github.com/google/uuid v1.6.0 + github.com/libp2p/go-libp2p v0.42.1 github.com/libp2p/go-libp2p-pubsub v0.14.2 github.com/mattn/go-sqlite3 v1.14.28 + github.com/multiformats/go-multiaddr v0.16.0 github.com/stretchr/testify v1.10.0 ) @@ -18,110 +19,99 @@ require ( github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/cgroups v1.1.0 // indirect - github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect - github.com/docker/go-units v0.5.0 // indirect - github.com/elastic/gosigar v0.14.3 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/flynn/noise v1.1.0 // indirect github.com/francoispqt/gojay v1.2.13 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/gopacket v1.1.19 // indirect - github.com/google/pprof v0.0.0-20250202011525-fc3143867406 // indirect + github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/huin/goupnp v1.3.0 // indirect github.com/ipfs/go-cid v0.5.0 // indirect - github.com/ipfs/go-log/v2 v2.5.1 // indirect + github.com/ipfs/go-log/v2 v2.6.0 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect - github.com/klauspost/compress v1.17.11 // indirect - github.com/klauspost/cpuid/v2 v2.2.9 // indirect - github.com/koron/go-ssdp v0.0.5 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/koron/go-ssdp v0.0.6 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect github.com/libp2p/go-flow-metrics v0.2.0 // indirect github.com/libp2p/go-libp2p-asn-util v0.4.1 // indirect github.com/libp2p/go-msgio v0.3.0 // indirect - github.com/libp2p/go-nat v0.2.0 // indirect github.com/libp2p/go-netroute v0.2.2 // indirect github.com/libp2p/go-reuseport v0.4.0 // indirect - github.com/libp2p/go-yamux/v4 v4.0.2 // indirect + github.com/libp2p/go-yamux/v5 v5.0.1 // indirect github.com/libp2p/zeroconf/v2 v2.2.0 // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/miekg/dns v1.1.63 // indirect + github.com/miekg/dns v1.1.66 // indirect github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect github.com/minio/sha256-simd v1.0.1 // indirect github.com/mr-tron/base58 v1.2.0 // indirect github.com/multiformats/go-base32 v0.1.0 // indirect github.com/multiformats/go-base36 v0.2.0 // indirect - github.com/multiformats/go-multiaddr v0.14.0 // indirect 
github.com/multiformats/go-multiaddr-dns v0.4.1 // indirect github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect github.com/multiformats/go-multibase v0.2.0 // indirect - github.com/multiformats/go-multicodec v0.9.0 // indirect + github.com/multiformats/go-multicodec v0.9.1 // indirect github.com/multiformats/go-multihash v0.2.3 // indirect - github.com/multiformats/go-multistream v0.6.0 // indirect + github.com/multiformats/go-multistream v0.6.1 // indirect github.com/multiformats/go-varint v0.0.7 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/onsi/ginkgo/v2 v2.22.2 // indirect - github.com/opencontainers/runtime-spec v1.2.0 // indirect + github.com/onsi/ginkgo/v2 v2.23.4 // indirect github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect github.com/pion/datachannel v1.5.10 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect - github.com/pion/dtls/v3 v3.0.4 // indirect - github.com/pion/ice/v2 v2.3.37 // indirect - github.com/pion/ice/v4 v4.0.6 // indirect - github.com/pion/interceptor v0.1.37 // indirect + github.com/pion/dtls/v3 v3.0.6 // indirect + github.com/pion/ice/v4 v4.0.10 // indirect + github.com/pion/interceptor v0.1.40 // indirect github.com/pion/logging v0.2.3 // indirect - github.com/pion/mdns v0.0.12 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/rtcp v1.2.15 // indirect - github.com/pion/rtp v1.8.11 // indirect - github.com/pion/sctp v1.8.35 // indirect - github.com/pion/sdp/v3 v3.0.10 // indirect - github.com/pion/srtp/v3 v3.0.4 // indirect + github.com/pion/rtp v1.8.19 // indirect + github.com/pion/sctp v1.8.39 // indirect + github.com/pion/sdp/v3 v3.0.13 // indirect + github.com/pion/srtp/v3 v3.0.6 // indirect github.com/pion/stun v0.6.1 // indirect github.com/pion/stun/v3 v3.0.0 // indirect github.com/pion/transport/v2 v2.2.10 // indirect github.com/pion/transport/v3 v3.0.7 // indirect - github.com/pion/turn/v2 v2.1.6 // indirect - github.com/pion/turn/v4 v4.0.0 // indirect - github.com/pion/webrtc/v4 v4.0.8 // indirect - github.com/pkg/errors v0.9.1 // indirect + github.com/pion/turn/v4 v4.0.2 // indirect + github.com/pion/webrtc/v4 v4.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_golang v1.20.5 // indirect - github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.62.0 // indirect - github.com/prometheus/procfs v0.15.1 // indirect + github.com/prometheus/client_golang v1.22.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.64.0 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect - github.com/quic-go/quic-go v0.49.0 // indirect + github.com/quic-go/quic-go v0.52.0 // indirect github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 // indirect - github.com/raulk/go-watchdog v1.3.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect - go.uber.org/dig v1.18.0 // indirect - go.uber.org/fx v1.23.0 // indirect - go.uber.org/mock v0.5.0 // indirect + go.uber.org/automaxprocs v1.6.0 // indirect + go.uber.org/dig v1.19.0 // indirect + go.uber.org/fx v1.24.0 // indirect + go.uber.org/mock v0.5.2 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/crypto v0.32.0 // indirect - golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // 
indirect - golang.org/x/mod v0.23.0 // indirect - golang.org/x/net v0.34.0 // indirect - golang.org/x/sync v0.11.0 // indirect - golang.org/x/sys v0.30.0 // indirect - golang.org/x/text v0.22.0 // indirect - golang.org/x/tools v0.29.0 // indirect - google.golang.org/protobuf v1.36.4 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.34.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - lukechampine.com/blake3 v1.3.0 // indirect + lukechampine.com/blake3 v1.4.1 // indirect ) // Remember to run `go mod tidy` after adding dependencies. diff --git a/networking/forwarder/go.sum b/networking/forwarder/go.sum index 75e179a9..5ba5ce9e 100644 --- a/networking/forwarder/go.sum +++ b/networking/forwarder/go.sum @@ -9,8 +9,6 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= -github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -20,33 +18,18 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= -github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= -github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c h1:pFUpOrbxDR6AkioZ1ySsx5yxlDQZ8stG2b88gTPxgJU= github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c/go.mod h1:6UhI8N9EjYm1c2odKpFpAYeR8dsBeM7PtzQhRgxRr9U= -github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5ilcvdfma9wOH6Y= -github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= -github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= -github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U0x++OzVrdms8= +github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= -github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo= -github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= @@ -60,12 +43,7 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= -github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -76,18 +54,16 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree 
v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20250202011525-fc3143867406 h1:wlQI2cYY0BsWmmPPAnxfQ8SDW0S3Jasn+4B8kXFxprg= -github.com/google/pprof v0.0.0-20250202011525-fc3143867406/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a h1://KbezygeMJZCSHH+HgUZiTeSoiuFspbMg1ge+eFj18= +github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= @@ -103,8 +79,8 @@ github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/ipfs/go-cid v0.5.0 h1:goEKKhaGm0ul11IHA7I6p1GmKz8kEYniqFopaB5Otwg= github.com/ipfs/go-cid v0.5.0/go.mod h1:0L7vmeNXpQpUS9vt+yEARkJ8rOg43DF3iPgn4GIN0mk= -github.com/ipfs/go-log/v2 v2.5.1 h1:1XdUzF7048prq4aBjDQQ4SL5RxftpRGdXhNRwKSAlcY= -github.com/ipfs/go-log/v2 v2.5.1/go.mod h1:prSpmC1Gpllc9UYWxDiZDreBYw7zp4Iqp1kOLU9U5UI= +github.com/ipfs/go-log/v2 v2.6.0 h1:2Nu1KKQQ2ayonKp4MPo6pXCjqw1ULc9iohRqWV5EYqg= +github.com/ipfs/go-log/v2 v2.6.0/go.mod h1:p+Efr3qaY5YXpx9TX7MoLCSEZX5boSWj9wh86P5HJa8= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jbenet/go-temp-err-catcher v0.1.0 h1:zpb3ZH6wIE8Shj2sKS+khgRvf7T7RABoLk/+KKHggpk= @@ -112,15 +88,14 @@ github.com/jbenet/go-temp-err-catcher v0.1.0/go.mod h1:0kJRvmDZXNMIiJirNPEYfhpPw github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod 
h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= -github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= -github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= -github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= -github.com/koron/go-ssdp v0.0.5 h1:E1iSMxIs4WqxTbIBLtmNBeOOC+1sCIXQeqTWVnpmwhk= -github.com/koron/go-ssdp v0.0.5/go.mod h1:Qm59B7hpKpDqfyRNWRNr00jGwLdXjDyZh6y7rH6VS0w= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/koron/go-ssdp v0.0.6 h1:Jb0h04599eq/CY7rB5YEqPS83HmRfHP2azkxMN2rFtU= +github.com/koron/go-ssdp v0.0.6/go.mod h1:0R9LfRJGek1zWTjN3JUNlm5INCDYGpRDfAptnct63fI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -134,8 +109,8 @@ github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6 github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg= github.com/libp2p/go-flow-metrics v0.2.0 h1:EIZzjmeOE6c8Dav0sNv35vhZxATIXWZg6j/C08XmmDw= github.com/libp2p/go-flow-metrics v0.2.0/go.mod h1:st3qqfu8+pMfh+9Mzqb2GTiwrAGjIPszEjZmtksN8Jc= -github.com/libp2p/go-libp2p v0.39.1 h1:1Ur6rPCf3GR+g8jkrnaQaM0ha2IGespsnNlCqJLLALE= -github.com/libp2p/go-libp2p v0.39.1/go.mod h1:3zicI8Lp7Isun+Afo/JOACUbbJqqR2owK6RQWFsVAbI= +github.com/libp2p/go-libp2p v0.42.1 h1:Rt8+5thie729NQk1gx1h/2t/+VIafWcqR1I+Kvw+UTg= +github.com/libp2p/go-libp2p v0.42.1/go.mod h1:4NGcjbD9OIvFiSRb0XueCO19zJ4kSPK5vkyyOUYmMro= github.com/libp2p/go-libp2p-asn-util v0.4.1 h1:xqL7++IKD9TBFMgnLPZR6/6iYhawHKHl950SO9L6n94= github.com/libp2p/go-libp2p-asn-util v0.4.1/go.mod h1:d/NI6XZ9qxw67b4e+NgpQexCIiFYJjErASrYW4PFDN8= github.com/libp2p/go-libp2p-pubsub v0.14.2 h1:nT5lFHPQOFJcp9CW8hpKtvbpQNdl2udJuzLQWbgRum8= @@ -144,21 +119,18 @@ github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUI github.com/libp2p/go-libp2p-testing v0.12.0/go.mod h1:KcGDRXyN7sQCllucn1cOOS+Dmm7ujhfEyXQL5lvkcPg= github.com/libp2p/go-msgio v0.3.0 h1:mf3Z8B1xcFN314sWX+2vOTShIE0Mmn2TXn3YCUQGNj0= github.com/libp2p/go-msgio v0.3.0/go.mod h1:nyRM819GmVaF9LX3l03RMh10QdOroF++NBbxAb0mmDM= -github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= -github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFPuZ8= github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU= -github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o= 
+github.com/libp2p/go-yamux/v5 v5.0.1 h1:f0WoX/bEF2E8SbE4c/k1Mo+/9z0O4oC/hWEA+nfYRSg= +github.com/libp2p/go-yamux/v5 v5.0.1/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU= github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q= github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= @@ -166,8 +138,8 @@ github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= -github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= -github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= +github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE= +github.com/miekg/dns v1.1.66/go.mod h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE= github.com/mikioh/tcp v0.0.0-20190314235350-803a9b46060c h1:bzE/A84HN25pxAuk9Eej1Kz9OUelF97nAc82bDquQI8= github.com/mikioh/tcp v0.0.0-20190314235350-803a9b46060c/go.mod h1:0SQS9kMwD2VsyFEB++InYyBJroV/FRmBgcydeSUcJms= github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b h1:z78hV3sbSMAUoyUMM0I83AUIT6Hu17AWfgjzIbtrYFc= @@ -188,34 +160,31 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= -github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU= -github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= +github.com/multiformats/go-multiaddr v0.16.0 h1:oGWEVKioVQcdIOBlYM8BH1rZDWOGJSqr9/BKl6zQ4qc= +github.com/multiformats/go-multiaddr v0.16.0/go.mod h1:JSVUmXDjsVFiW7RjIFMP7+Ev+h1DTbiJgVeTV/tcmP0= github.com/multiformats/go-multiaddr-dns v0.4.1 h1:whi/uCLbDS3mSEUMb1MsoT4uzUeZB0N32yzufqS0i5M= github.com/multiformats/go-multiaddr-dns v0.4.1/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= github.com/multiformats/go-multiaddr-fmt v0.1.0/go.mod h1:hGtDIW4PU4BqJ50gW2quDuPVjyWNZxToGUh/HwTZYJo= github.com/multiformats/go-multibase v0.2.0 h1:isdYCVLvksgWlMW9OZRYJEa9pZETFivncJHmHnnd87g= github.com/multiformats/go-multibase v0.2.0/go.mod 
h1:bFBZX4lKCA/2lyOFSAoKH5SS6oPyjtnzK/XTFDPkNuk= -github.com/multiformats/go-multicodec v0.9.0 h1:pb/dlPnzee/Sxv/j4PmkDRxCOi3hXTz3IbPKOXWJkmg= -github.com/multiformats/go-multicodec v0.9.0/go.mod h1:L3QTQvMIaVBkXOXXtVmYE+LI16i14xuaojr/H7Ai54k= +github.com/multiformats/go-multicodec v0.9.1 h1:x/Fuxr7ZuR4jJV4Os5g444F7xC4XmyUaT/FWtE+9Zjo= +github.com/multiformats/go-multicodec v0.9.1/go.mod h1:LLWNMtyV5ithSBUo3vFIMaeDy+h3EbkMTek1m+Fybbo= github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= github.com/multiformats/go-multihash v0.2.3 h1:7Lyc8XfX/IY2jWb/gI7JP+o7JEq9hOa7BFvVU9RSh+U= github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsCKRVch90MdaGiKsvSM= -github.com/multiformats/go-multistream v0.6.0 h1:ZaHKbsL404720283o4c/IHQXiS6gb8qAN5EIJ4PN5EA= -github.com/multiformats/go-multistream v0.6.0/go.mod h1:MOyoG5otO24cHIg8kf9QW2/NozURlkP/rvi2FQJyCPg= +github.com/multiformats/go-multistream v0.6.1 h1:4aoX5v6T+yWmc2raBHsTvzmFhOI8WVOer28DeBBEYdQ= +github.com/multiformats/go-multistream v0.6.1/go.mod h1:ksQf6kqHAb6zIsyw7Zm+gAuVo57Qbq84E27YlYqavqw= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU= -github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk= -github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= -github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= -github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= -github.com/opencontainers/runtime-spec v1.2.0 h1:z97+pHb3uELt/yiAWD691HNHQIF07bE7dzrbT927iTk= -github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= +github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= +github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU= +github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= @@ -224,33 +193,29 @@ github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oL github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= -github.com/pion/dtls/v3 v3.0.4/go.mod 
h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= -github.com/pion/ice/v2 v2.3.37 h1:ObIdaNDu1rCo7hObhs34YSBcO7fjslJMZV0ux+uZWh0= -github.com/pion/ice/v2 v2.3.37/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ= -github.com/pion/ice/v4 v4.0.6 h1:jmM9HwI9lfetQV/39uD0nY4y++XZNPhvzIPCb8EwxUM= -github.com/pion/ice/v4 v4.0.6/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= -github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= -github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y= +github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E= +github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU= +github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4= +github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= +github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4= +github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= -github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= -github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= -github.com/pion/rtp v1.8.11 h1:17xjnY5WO5hgO6SD3/NTIUPvSFw/PbLsIJyz1r1yNIk= -github.com/pion/rtp v1.8.11/go.mod h1:8uMBJj32Pa1wwx8Fuv/AsFhn8jsgw+3rUC2PfoBZ8p4= -github.com/pion/sctp v1.8.35 h1:qwtKvNK1Wc5tHMIYgTDJhfZk7vATGVHhXbUDfHbYwzA= -github.com/pion/sctp v1.8.35/go.mod h1:EcXP8zCYVTRy3W9xtOF7wJm1L1aXfKRQzaM33SjQlzg= -github.com/pion/sdp/v3 v3.0.10 h1:6MChLE/1xYB+CjumMw+gZ9ufp2DPApuVSnDT8t5MIgA= -github.com/pion/sdp/v3 v3.0.10/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= -github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= -github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ= +github.com/pion/rtp v1.8.19 h1:jhdO/3XhL/aKm/wARFVmvTfq0lC/CvN1xwYKmduly3c= +github.com/pion/rtp v1.8.19/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk= +github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE= +github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= +github.com/pion/sdp/v3 v3.0.13 h1:uN3SS2b+QDZnWXgdr69SM8KB4EbcnPnPf2Laxhty/l4= +github.com/pion/sdp/v3 v3.0.13/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= +github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4= +github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY= github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= @@ -259,45 +224,38 @@ 
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1A github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q= github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= -github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= -github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc= -github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= -github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= -github.com/pion/webrtc/v4 v4.0.8 h1:T1ZmnT9qxIJIt4d8XoiMOBrTClGHDDXNg9e/fh018Qc= -github.com/pion/webrtc/v4 v4.0.8/go.mod h1:HHBeUVBAC+j4ZFnYhovEFStF02Arb1EyD4G7e7HBTJw= +github.com/pion/turn/v4 v4.0.2 h1:ZqgQ3+MjP32ug30xAbD6Mn+/K4Sxi3SdNOTFf+7mpps= +github.com/pion/turn/v4 v4.0.2/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs= +github.com/pion/webrtc/v4 v4.1.2 h1:mpuUo/EJ1zMNKGE79fAdYNFZBX790KE7kQQpLMjjR54= +github.com/pion/webrtc/v4 v4.1.2/go.mod h1:xsCXiNAmMEjIdFxAYU0MbB3RwRieJsegSB2JZsGN+8U= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= -github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= -github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= +github.com/prometheus/common 
v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= -github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.49.0 h1:w5iJHXwHxs1QxyBv1EHKuC50GX5to8mJAxvtnttJp94= -github.com/quic-go/quic-go v0.49.0/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= +github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA= +github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw= -github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk= -github.com/raulk/go-watchdog v1.3.0/go.mod h1:fIvOnLbF0b0ZwkB9YU4mOW9Did//4vPZtDqv66NfsMU= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= @@ -319,10 +277,8 @@ github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= -github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= @@ -331,9 +287,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx 
v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= @@ -341,7 +294,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= @@ -349,23 +301,20 @@ github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw= -go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= -go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg= -go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU= -go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= +go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= +go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg= +go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/mock v0.5.2 
h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= @@ -382,24 +331,22 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc= -golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= -golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -407,7 +354,6 @@ 
golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -415,7 +361,6 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -423,8 +368,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -440,38 +385,31 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20180810173357-98c5dad5d1a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -488,28 +426,25 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod 
h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= -golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -530,26 +465,22 @@ google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmE google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM= -google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf 
v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -lukechampine.com/blake3 v1.3.0 h1:sJ3XhFINmHSrYCgl958hscfIa3bw8x4DqMP3u1YvoYE= -lukechampine.com/blake3 v1.3.0/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= +lukechampine.com/blake3 v1.4.1 h1:I3Smz7gso8w4/TunLKec6K2fn+kyKtDxr/xcQEN84Wg= +lukechampine.com/blake3 v1.4.1/go.mod h1:QFosUxmjB8mnrWFSNwKmvxHpfY72bmD2tQ0kBMM3kwo= sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/networking/forwarder/main.go b/networking/forwarder/main.go index dd2a9ea4..3699364d 100644 --- a/networking/forwarder/main.go +++ b/networking/forwarder/main.go @@ -11,6 +11,7 @@ import ( ) var nodeID = flag.String("node-id", "", "Node ID (defaults to FORWARDER_NODE_ID env var or a new UUID)") +var eventsDBPath = flag.String("events-db", "", "Path to the worker events SQLite database") func main() { flag.Parse() @@ -23,6 +24,12 @@ func main() { } log.Printf("Starting forwarder with node ID: %s", id) + // Set the events database path if provided + if *eventsDBPath != "" { + forwarder.SetEventsDBPath(*eventsDBPath) + log.Printf("Using events database: %s", *eventsDBPath) + } + args := flag.Args() if len(args) == 0 { log.Fatal("forwarding pairs argument is required as the first positional argument (of the form {source}|{sink}) where source and sink sqlite:db_file:table_name or libp2p:topic") diff --git a/networking/forwarder/src/event_writer.go b/networking/forwarder/src/event_writer.go new file mode 100644 index 00000000..b0ebb9dd --- /dev/null +++ b/networking/forwarder/src/event_writer.go @@ -0,0 +1,259 @@ +package forwarder + +import ( + "database/sql" + "encoding/json" + 
"fmt" + "log" + "strconv" + "sync" + + "github.com/google/uuid" + "github.com/libp2p/go-libp2p/core/network" + _ "github.com/mattn/go-sqlite3" + "github.com/multiformats/go-multiaddr" +) + +var ( + eventsDBPath string + eventsDB *sql.DB + eventsDBMu sync.Mutex +) + +// SetEventsDBPath sets the path to the events database +func SetEventsDBPath(path string) { + eventsDBMu.Lock() + defer eventsDBMu.Unlock() + eventsDBPath = path +} + +// Event types matching Python's _EventType enum +const ( + EventTypeTopologyEdgeCreated = "TopologyEdgeCreated" + EventTypeTopologyEdgeDeleted = "TopologyEdgeDeleted" +) + +// ConnectionProfile matches Python's ConnectionProfile (optional) +type ConnectionProfile struct { + Throughput float64 `json:"throughput"` + Latency float64 `json:"latency"` + Jitter float64 `json:"jitter"` +} + +// Multiaddr matches Python's Multiaddr structure +type Multiaddr struct { + Address string `json:"address"` + IPv4Address string `json:"ipv4_address,omitempty"` + Port int `json:"port,omitempty"` +} + +// Connection matches Python's Connection model +type Connection struct { + LocalNodeID string `json:"local_node_id"` + SendBackNodeID string `json:"send_back_node_id"` + LocalMultiaddr Multiaddr `json:"local_multiaddr"` + SendBackMultiaddr Multiaddr `json:"send_back_multiaddr"` + ConnectionProfile *ConnectionProfile `json:"connection_profile"` +} + +// TopologyEdgeCreated matches Python's TopologyEdgeCreated event +type TopologyEdgeCreated struct { + EventType string `json:"event_type"` + EventID string `json:"event_id"` + Edge Connection `json:"edge"` +} + +// TopologyEdgeDeleted matches Python's TopologyEdgeDeleted event +type TopologyEdgeDeleted struct { + EventType string `json:"event_type"` + EventID string `json:"event_id"` + Edge Connection `json:"edge"` +} + +// initEventsDB initializes the events database connection +func initEventsDB() error { + eventsDBMu.Lock() + defer eventsDBMu.Unlock() + + if eventsDB != nil { + return nil // Already initialized + } + + if eventsDBPath == "" { + return nil // No events DB configured + } + + var err error + eventsDB, err = sql.Open("sqlite3", eventsDBPath) + if err != nil { + return fmt.Errorf("failed to open events database: %w", err) + } + + // Create table if it doesn't exist (matching Python's schema) + createTableSQL := ` + CREATE TABLE IF NOT EXISTS events ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + origin TEXT NOT NULL, + event_type TEXT NOT NULL, + event_id TEXT NOT NULL, + event_data TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_events_origin ON events(origin); + CREATE INDEX IF NOT EXISTS idx_events_event_type ON events(event_type); + CREATE INDEX IF NOT EXISTS idx_events_created_at ON events(created_at); + ` + _, err = eventsDB.Exec(createTableSQL) + if err != nil { + eventsDB.Close() + eventsDB = nil + return fmt.Errorf("failed to create events table: %w", err) + } + + return nil +} + +// writeEvent writes an event to the database +func writeEvent(eventType string, eventData interface{}) error { + if eventsDB == nil { + if err := initEventsDB(); err != nil { + return err + } + if eventsDB == nil { + return nil // No events DB configured + } + } + + // Serialize event data to JSON + jsonData, err := json.Marshal(eventData) + if err != nil { + return fmt.Errorf("failed to marshal event data: %w", err) + } + + // Extract event ID from the event data + var eventID string + switch e := eventData.(type) { + case *TopologyEdgeCreated: + eventID = e.EventID + 
case *TopologyEdgeDeleted: + eventID = e.EventID + default: + eventID = uuid.New().String() + } + + // Insert event into database + insertSQL := `INSERT INTO events (origin, event_type, event_id, event_data) VALUES (?, ?, ?, ?)` + _, err = eventsDB.Exec(insertSQL, GetNodeId(), eventType, eventID, string(jsonData)) + if err != nil { + return fmt.Errorf("failed to insert event: %w", err) + } + + return nil +} + +// NotifeeHandler implements the libp2p network.Notifiee interface +type NotifeeHandler struct{} + +// Listen is called when network starts listening on an addr +func (n *NotifeeHandler) Listen(net network.Network, ma multiaddr.Multiaddr) {} + +// ListenClose is called when network stops listening on an addr +func (n *NotifeeHandler) ListenClose(net network.Network, ma multiaddr.Multiaddr) {} + +// Connected is called when a connection is opened +func (n *NotifeeHandler) Connected(net network.Network, conn network.Conn) { + remotePeer := conn.RemotePeer() + localAddr := conn.LocalMultiaddr() + remoteAddr := conn.RemoteMultiaddr() + + // Get the actual node IDs (not peer IDs) + localNodeID := GetNodeId() + + // For remote node, we need to extract from peer ID or use a mapping + // For now, we'll use the peer ID as a placeholder + // TODO: Implement proper node ID mapping/discovery + remoteNodeID := remotePeer.String() + + // Create connection event + event := &TopologyEdgeCreated{ + EventType: EventTypeTopologyEdgeCreated, + EventID: uuid.New().String(), + Edge: Connection{ + LocalNodeID: localNodeID, + SendBackNodeID: remoteNodeID, + LocalMultiaddr: parseMultiaddr(localAddr), + SendBackMultiaddr: parseMultiaddr(remoteAddr), + ConnectionProfile: nil, // TODO: Add connection profiling if needed + }, + } + + // Write event to database + if err := writeEvent(EventTypeTopologyEdgeCreated, event); err != nil { + log.Printf("Failed to write edge created event: %v", err) + } else { + log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID) + } +} + +// Disconnected is called when a connection is closed +func (n *NotifeeHandler) Disconnected(net network.Network, conn network.Conn) { + remotePeer := conn.RemotePeer() + localAddr := conn.LocalMultiaddr() + remoteAddr := conn.RemoteMultiaddr() + + // Get the actual node IDs (not peer IDs) + localNodeID := GetNodeId() + remoteNodeID := remotePeer.String() // TODO: Implement proper node ID mapping + + // Create disconnection event + event := &TopologyEdgeDeleted{ + EventType: EventTypeTopologyEdgeDeleted, + EventID: uuid.New().String(), + Edge: Connection{ + LocalNodeID: localNodeID, + SendBackNodeID: remoteNodeID, + LocalMultiaddr: parseMultiaddr(localAddr), + SendBackMultiaddr: parseMultiaddr(remoteAddr), + ConnectionProfile: nil, + }, + } + + // Write event to database + if err := writeEvent(EventTypeTopologyEdgeDeleted, event); err != nil { + log.Printf("Failed to write edge deleted event: %v", err) + } else { + log.Printf("Wrote edge deleted event: %s -> %s", localNodeID, remoteNodeID) + } +} + +// OpenedStream is called when a stream is opened +func (n *NotifeeHandler) OpenedStream(net network.Network, str network.Stream) {} + +// ClosedStream is called when a stream is closed +func (n *NotifeeHandler) ClosedStream(net network.Network, str network.Stream) {} + +// parseMultiaddr converts a libp2p multiaddr to our Multiaddr struct +func parseMultiaddr(ma multiaddr.Multiaddr) Multiaddr { + result := Multiaddr{ + Address: ma.String(), + } + + // Extract IPv4 address if present + if ipStr, err := 
ma.ValueForProtocol(multiaddr.P_IP4); err == nil { + result.IPv4Address = ipStr + } + + // Extract port if present + if portStr, err := ma.ValueForProtocol(multiaddr.P_TCP); err == nil { + if port, err := strconv.Atoi(portStr); err == nil { + result.Port = port + } + } + + return result +} + +// GetNotifee returns a singleton instance of the notifee handler +func GetNotifee() network.Notifiee { + return &NotifeeHandler{} +} \ No newline at end of file diff --git a/networking/forwarder/src/libp2p.go b/networking/forwarder/src/libp2p.go index 584e2b04..d25b1811 100644 --- a/networking/forwarder/src/libp2p.go +++ b/networking/forwarder/src/libp2p.go @@ -6,6 +6,10 @@ import ( "crypto/sha256" "encoding/json" "log" + "net" + "os" + "sort" + "strings" "sync" "time" @@ -15,9 +19,11 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" mdns "github.com/libp2p/go-libp2p/p2p/discovery/mdns" "github.com/libp2p/go-libp2p/p2p/security/noise" + "github.com/multiformats/go-multiaddr" ) var node host.Host @@ -28,22 +34,337 @@ var mu sync.Mutex var refCount int var topicsMap = make(map[string]*pubsub.Topic) +// Connection retry state tracking +type peerConnState struct { + retryCount int + lastAttempt time.Time +} + +var peerLastAddrs = make(map[peer.ID][]multiaddr.Multiaddr) +var addrsMu sync.Mutex + +var connecting = make(map[peer.ID]bool) +var connMu sync.Mutex +var peerRetryState = make(map[peer.ID]*peerConnState) +var retryMu sync.Mutex + +const ( + maxRetries = 5 // Increased for more tolerance to transient failures + initialBackoff = 2 * time.Second + maxBackoff = 33 * time.Second + retryResetTime = 1 * time.Minute // Reduced for faster recovery after max retries +) + type discoveryNotifee struct { h host.Host } +// sortAddrs returns a sorted copy of addresses for comparison +func sortAddrs(addrs []multiaddr.Multiaddr) []multiaddr.Multiaddr { + s := make([]multiaddr.Multiaddr, len(addrs)) + copy(s, addrs) + sort.Slice(s, func(i, j int) bool { + return s[i].String() < s[j].String() + }) + return s +} + +// addrsChanged checks if two address sets differ +func addrsChanged(a, b []multiaddr.Multiaddr) bool { + if len(a) != len(b) { + return true + } + sa := sortAddrs(a) + sb := sortAddrs(b) + for i := range sa { + if !sa[i].Equal(sb[i]) { + return true + } + } + return false +} + +// isAddressValid checks if an address should be used for connections +func isAddressValid(addr multiaddr.Multiaddr) bool { + // Allow loopback for testing if env var is set + allowLoopback := os.Getenv("FORWARDER_ALLOW_LOOPBACK") == "true" + + // Check IPv4 addresses + ipStr, err := addr.ValueForProtocol(multiaddr.P_IP4) + if err == nil && ipStr != "" { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + // Filter out loopback, unspecified addresses (unless testing) + if !allowLoopback && (ip.IsLoopback() || ip.IsUnspecified()) { + return false + } + if ip.IsUnspecified() { + return false + } + // Filter out common VPN ranges (Tailscale uses 100.64.0.0/10) + if ip.To4() != nil && ip.To4()[0] == 100 && ip.To4()[1] >= 64 && ip.To4()[1] <= 127 { + return false + } + } + + // Check IPv6 addresses + ipStr, err = addr.ValueForProtocol(multiaddr.P_IP6) + if err == nil && ipStr != "" { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + // Filter out loopback, unspecified addresses (unless testing) + if !allowLoopback && (ip.IsLoopback() || 
ip.IsUnspecified()) { + return false + } + if ip.IsUnspecified() { + return false + } + // Filter out Tailscale IPv6 (fd7a:115c:a1e0::/48) + if strings.HasPrefix(strings.ToLower(ipStr), "fd7a:115c:a1e0:") { + return false + } + } + + return true +} + +// customInterfaceAddresses returns IPs only from interfaces that are up and running (has link) +func customInterfaceAddresses() ([]net.IP, error) { + var ips []net.IP + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, ifi := range ifaces { + if ifi.Flags&net.FlagUp == 0 || ifi.Flags&net.FlagRunning == 0 { + continue + } + addrs, err := ifi.Addrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP != nil { + ips = append(ips, ipnet.IP) + } + } + } + return ips, nil +} + +// customAddrsFactory expands wildcard listen addrs to actual IPs on up+running interfaces, then filters +func customAddrsFactory(listenAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr { + ips, err := customInterfaceAddresses() + if err != nil { + log.Printf("Error getting interface IPs: %v", err) + return nil + } + + var advAddrs []multiaddr.Multiaddr + for _, la := range listenAddrs { + comps := multiaddr.Split(la) + if len(comps) == 0 { + continue + } + first := comps[0] + protos := first.Protocols() + if len(protos) == 0 { + continue + } + code := protos[0].Code + val, err := first.ValueForProtocol(code) + var isWildcard bool + if err == nil && ((code == multiaddr.P_IP4 && val == "0.0.0.0") || (code == multiaddr.P_IP6 && val == "::")) { + isWildcard = true + } + + if isWildcard { + // Expand to each valid IP + for _, ip := range ips { + var pcodeStr string + if ip.To4() != nil { + pcodeStr = "4" + } else { + pcodeStr = "6" + } + newIPStr := "/ip" + pcodeStr + "/" + ip.String() + newIPMA, err := multiaddr.NewMultiaddr(newIPStr) + if err != nil { + continue + } + var newComps []multiaddr.Multiaddrer + newComps = append(newComps, newIPMA) + for _, c := range comps[1:] { + newComps = append(newComps, c.Multiaddr()) + } + newa := multiaddr.Join(newComps...) 
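+				// Illustrative example (hypothetical values): a wildcard listen addr of
+				// /ip4/0.0.0.0/tcp/4001 on a host whose interface IP is 192.168.1.7
+				// expands to newa = /ip4/192.168.1.7/tcp/4001 here, which must still
+				// pass the isAddressValid filter below.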
+ if isAddressValid(newa) { + advAddrs = append(advAddrs, newa) + } + } + } else if isAddressValid(la) { + advAddrs = append(advAddrs, la) + } + } + return advAddrs +} + func (n *discoveryNotifee) HandlePeerFound(pi peer.AddrInfo) { - if n.h.ID() >= pi.ID { - return - } + log.Printf("mDNS discovered peer %s with %d addresses", pi.ID, len(pi.Addrs)) + + // Check if already connected first if n.h.Network().Connectedness(pi.ID) == network.Connected { + log.Printf("Already connected to peer %s", pi.ID) return } - ctx := context.Background() + + // Clear any existing addresses for this peer to ensure we use only fresh ones from mDNS + ps := n.h.Peerstore() + ps.ClearAddrs(pi.ID) + log.Printf("Cleared old addresses for peer %s", pi.ID) + + // During normal operation, only higher ID connects to avoid double connections + // But if we have retry state for this peer, both sides should attempt + // Also, if we have no connections at all, both sides should attempt + retryMu.Lock() + _, hasRetryState := peerRetryState[pi.ID] + retryMu.Unlock() + + // Check if we should skip based on ID comparison + // Skip only if we have a higher ID, no retry state, and we already have connections + if n.h.ID() >= pi.ID && !hasRetryState && len(n.h.Network().Peers()) > 0 { + log.Printf("Skipping initial connection to peer %s (lower ID)", pi.ID) + return + } + + // Filter addresses before attempting connection + var filteredAddrs []multiaddr.Multiaddr + for _, addr := range pi.Addrs { + if isAddressValid(addr) { + filteredAddrs = append(filteredAddrs, addr) + log.Printf("Valid address for %s: %s", pi.ID, addr) + } else { + log.Printf("Filtered out address for %s: %s", pi.ID, addr) + } + } + + if len(filteredAddrs) == 0 { + log.Printf("No valid addresses for peer %s after filtering, skipping connection attempt", pi.ID) + return + } + + // Check for address changes and reset retries if changed + addrsMu.Lock() + lastAddrs := peerLastAddrs[pi.ID] + addrsMu.Unlock() + if addrsChanged(lastAddrs, filteredAddrs) { + log.Printf("Detected address change for peer %s, resetting retry count", pi.ID) + retryMu.Lock() + if state, ok := peerRetryState[pi.ID]; ok { + state.retryCount = 0 + } + retryMu.Unlock() + // Update last known addresses + addrsMu.Lock() + peerLastAddrs[pi.ID] = append([]multiaddr.Multiaddr(nil), filteredAddrs...) 
// Copy
+		addrsMu.Unlock()
+	}
+
+	pi.Addrs = filteredAddrs
+
+	// Add the filtered addresses to the peerstore with a reasonable TTL
+	ps.AddAddrs(pi.ID, filteredAddrs, peerstore.TempAddrTTL)
+
+	// Attempt connection with retry logic
+	go n.connectWithRetry(pi)
+}
+
+func (n *discoveryNotifee) connectWithRetry(pi peer.AddrInfo) {
+	// Serialize connection attempts per peer
+	connMu.Lock()
+	if connecting[pi.ID] {
+		connMu.Unlock()
+		log.Printf("Already connecting to peer %s, skipping duplicate attempt", pi.ID)
+		return
+	}
+	connecting[pi.ID] = true
+	connMu.Unlock()
+	defer func() {
+		connMu.Lock()
+		delete(connecting, pi.ID)
+		connMu.Unlock()
+	}()
+
+	retryMu.Lock()
+	state, exists := peerRetryState[pi.ID]
+	if !exists {
+		state = &peerConnState{}
+		peerRetryState[pi.ID] = state
+	}
+
+	// Check if we've exceeded max retries
+	if state.retryCount >= maxRetries {
+		// Check if enough time has passed to reset retry count
+		if time.Since(state.lastAttempt) > retryResetTime {
+			state.retryCount = 0
+			log.Printf("Reset retry count for peer %s due to time elapsed", pi.ID)
+		} else {
+			retryMu.Unlock()
+			log.Printf("Max retries reached for peer %s, skipping", pi.ID)
+			return
+		}
+	}
+
+	// Calculate backoff duration: initialBackoff doubled per retry, capped at maxBackoff
+	backoffDuration := time.Duration(1<<state.retryCount) * initialBackoff
+	if backoffDuration > maxBackoff {
+		backoffDuration = maxBackoff
+	}
+
+	// Check if we need to wait before retrying
+	if state.retryCount > 0 && time.Since(state.lastAttempt) < backoffDuration {
+		retryMu.Unlock()
+		log.Printf("Backoff active for peer %s, skipping attempt", pi.ID)
+		return
+	}
+
+	state.lastAttempt = time.Now()
+	retryMu.Unlock()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
 	if err := n.h.Connect(ctx, pi); err != nil {
-		log.Printf("Failed to connect to %s: %v", pi.ID.String(), err)
+		log.Printf("Failed to connect to %s (attempt %d/%d): %v", pi.ID, state.retryCount+1, maxRetries, err)
+
+		retryMu.Lock()
+		state.retryCount++
+		retryMu.Unlock()
+
+		// Schedule retry if we haven't exceeded max attempts
+		if state.retryCount < maxRetries {
+			time.AfterFunc(backoffDuration, func() {
+				// Check if we're still not connected before retrying
+				if n.h.Network().Connectedness(pi.ID) != network.Connected {
+					n.connectWithRetry(pi)
+				}
+			})
+		}
 	} else {
-		log.Printf("Connected to %s", pi.ID.String())
+		log.Printf("Successfully connected to %s", pi.ID)
+
+		// Reset retry state on successful connection
+		retryMu.Lock()
+		delete(peerRetryState, pi.ID)
+		retryMu.Unlock()
+		addrsMu.Lock()
+		delete(peerLastAddrs, pi.ID)
+		addrsMu.Unlock()
+		log.Printf("Cleared cached addresses for peer %s after successful connection", pi.ID)
 	}
 }
 
@@ -76,6 +397,9 @@ func getNode(ctx context.Context) {
 	opts = append(opts, libp2p.EnableHolePunching()) // Better NAT traversal
 	opts = append(opts, libp2p.EnableRelay())        // Allow relaying
 
+	// Custom address factory to avoid advertising down interfaces
+	opts = append(opts, libp2p.AddrsFactory(customAddrsFactory))
+
 	node, err = libp2p.New(opts...)
if err != nil { log.Fatalf("failed to create host: %v", err) @@ -103,9 +427,118 @@ func getNode(ctx context.Context) { node.Close() log.Fatalf("failed to start mdns service: %v", err) } + + // Register disconnect notifiee to clear stale addresses + node.Network().Notify(&disconnectNotifee{}) + + // Register event notifiee to track topology changes + node.Network().Notify(GetNotifee()) + + // Start a goroutine to periodically trigger mDNS discovery + go periodicMDNSDiscovery() }) } +// periodicMDNSDiscovery ensures mDNS continues to work after network changes +func periodicMDNSDiscovery() { + // Start with faster checks, then slow down + fastCheckDuration := 5 * time.Second + slowCheckDuration := 30 * time.Second + currentDuration := fastCheckDuration + noConnectionCount := 0 + + ticker := time.NewTicker(currentDuration) + defer ticker.Stop() + + for range ticker.C { + if mdnsSer == nil || node == nil { + return + } + + // Log current connection status + peers := node.Network().Peers() + if len(peers) == 0 { + noConnectionCount++ + log.Printf("No connected peers (check #%d), mDNS service running: %v", noConnectionCount, mdnsSer != nil) + + // Force mDNS to re-announce when we have no peers + // This helps recovery after network interface changes + if noConnectionCount > 1 { // Skip first check to avoid unnecessary restart + forceRestartMDNS() + } + + // Keep fast checking when disconnected + if currentDuration != fastCheckDuration { + currentDuration = fastCheckDuration + ticker.Reset(currentDuration) + log.Printf("Switching to fast mDNS checks (every %v)", currentDuration) + } + } else { + log.Printf("Currently connected to %d peers", len(peers)) + noConnectionCount = 0 + + // Switch to slow checking when connected + if currentDuration != slowCheckDuration { + currentDuration = slowCheckDuration + ticker.Reset(currentDuration) + log.Printf("Switching to slow mDNS checks (every %v)", currentDuration) + } + } + } +} + +// forceRestartMDNS restarts the mDNS service to force re-announcement +func forceRestartMDNS() { + mu.Lock() + defer mu.Unlock() + + if mdnsSer != nil && node != nil { + log.Printf("Force restarting mDNS service for re-announcement") + oldMdns := mdnsSer + rendezvous := "forwarder_network" + notifee := &discoveryNotifee{h: node} + newMdns := mdns.NewMdnsService(node, rendezvous, notifee) + + if err := newMdns.Start(); err != nil { + log.Printf("Failed to restart mDNS service: %v", err) + } else { + oldMdns.Close() + mdnsSer = newMdns + log.Printf("Successfully restarted mDNS service") + } + } +} + +// disconnectNotifee clears stale peer addresses on disconnect +type disconnectNotifee struct{} + +func (d *disconnectNotifee) Connected(network.Network, network.Conn) {} +func (d *disconnectNotifee) Disconnected(n network.Network, c network.Conn) { + p := c.RemotePeer() + ps := n.Peerstore() + + // Clear all addresses from peerstore to force fresh discovery on reconnect + ps.ClearAddrs(p) + + // Also clear retry state for this peer + retryMu.Lock() + delete(peerRetryState, p) + retryMu.Unlock() + + log.Printf("Cleared stale addresses and retry state for disconnected peer %s", p) + + // Try to restart mDNS discovery after a short delay to handle network interface changes + go func() { + time.Sleep(2 * time.Second) + log.Printf("Triggering mDNS re-discovery after disconnect") + forceRestartMDNS() + }() +} +func (d *disconnectNotifee) OpenedStream(network.Network, network.Stream) {} +func (d *disconnectNotifee) ClosedStream(network.Network, network.Stream) {} +func (d 
*disconnectNotifee) Listen(network.Network, multiaddr.Multiaddr) {} +func (d *disconnectNotifee) ListenClose(network.Network, multiaddr.Multiaddr) {} + type libP2PConnector struct { topic string sub *pubsub.Subscription diff --git a/pyproject.toml b/pyproject.toml index 7d8aad79..2404533f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,4 +113,8 @@ extend-select = ["I", "N", "B", "A", "PIE", "SIM"] [tool.pytest.ini_options] pythonpath = "." -asyncio_mode = "auto" \ No newline at end of file +asyncio_mode = "auto" +markers = [ + "slow: marks tests as slow (deselected by default)" +] +addopts = "-m 'not slow'" diff --git a/run.sh b/run.sh index f63eea07..c32b9345 100755 --- a/run.sh +++ b/run.sh @@ -40,7 +40,7 @@ fi # Second command (master) - changes based on replica flag if [ "$REPLICA" = true ]; then - osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export EXO_RUN_AS_REPLICA=1 EXO_HOME=.exo_replica API_PORT=8001; uv run -m master.main'\"" + osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true EXO_RUN_AS_REPLICA=1 EXO_HOME=.exo_replica API_PORT=8001; uv run -m master.main'\"" else - osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c uv run -m master.main\"" + osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true; uv run -m master.main'\"" fi \ No newline at end of file diff --git a/rust/discovery/src/behaviour.rs b/rust/discovery/src/behaviour.rs index 15efe265..382fe241 100644 --- a/rust/discovery/src/behaviour.rs +++ b/rust/discovery/src/behaviour.rs @@ -200,7 +200,7 @@ fn mdns_behaviour(keypair: &identity::Keypair) -> AnyResult enable IPv6 let mdns_config = Config { - // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? 
figure out how to make work
+        enable_ipv6: true,
         ..Default::default()
     };
 
diff --git a/rust/discovery/src/lib.rs b/rust/discovery/src/lib.rs
index bcc1075a..b1a5abdc 100644
--- a/rust/discovery/src/lib.rs
+++ b/rust/discovery/src/lib.rs
@@ -17,6 +17,7 @@ use crate::behaviour::{discovery_behaviour, DiscoveryBehaviour};
 use crate::transport::discovery_transport;
 use libp2p::{identity, Swarm, SwarmBuilder};
+use std::net::IpAddr;
 
 pub mod behaviour;
 pub mod transport;
@@ -49,11 +50,18 @@ pub fn discovery_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm<DiscoveryBehaviour>> {
+    // Best-effort IPv6 listener alongside the existing IPv4 one
+    match swarm.listen_on("/ip6/::/tcp/0".parse()?) {
+        Ok(_) => log::info!("RUST: Successfully listening on IPv6"),
+        Err(e) => log::warn!("RUST: Failed to listen on IPv6 (this is okay if IPv6 is not available): {:?}", e),
+    }
 
     Ok(swarm)
 }
diff --git a/rust/discovery/src/transport.rs b/rust/discovery/src/transport.rs
index ee7213d8..189d65c5 100644
--- a/rust/discovery/src/transport.rs
+++ b/rust/discovery/src/transport.rs
@@ -33,7 +33,8 @@ fn tcp_transport(
     };
 
     // `TCP_NODELAY` enabled => avoid latency
-    let tcp_config = Config::default().nodelay(true);
+    let tcp_config = Config::default()
+        .nodelay(true);
 
     // V1 + lazy flushing => 0-RTT negotiation
     let upgrade_version = Version::V1Lazy;
diff --git a/rust/exo_pyo3_bindings/src/discovery.rs b/rust/exo_pyo3_bindings/src/discovery.rs
index 3ba8bbc6..37772807 100644
--- a/rust/exo_pyo3_bindings/src/discovery.rs
+++ b/rust/exo_pyo3_bindings/src/discovery.rs
@@ -18,12 +18,14 @@ use libp2p::multiaddr::multiaddr;
 use libp2p::swarm::dial_opts::DialOpts;
 use libp2p::swarm::{ConnectionId, SwarmEvent, ToSwarm};
 use libp2p::{Multiaddr, PeerId, Swarm, gossipsub, mdns};
+use std::net::IpAddr;
 use pyo3::prelude::{PyModule, PyModuleMethods as _};
 use pyo3::{Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python, pymethods};
 use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
 use std::convert::identity;
 use std::error::Error;
 use tokio::sync::mpsc;
+use tokio::time::{interval, Duration};
 
 struct ConnectionUpdate {
     /// Identity of the peer that we have connected to.
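// Illustrative behaviour of the is_address_valid() filter introduced in the
// next hunk (hypothetical multiaddrs, not part of the patch):
//   /ip4/127.0.0.1/tcp/4001          -> rejected (loopback)
//   /ip4/100.96.3.8/tcp/4001         -> rejected (100.64.0.0/10, Tailscale/CGNAT)
//   /ip6/fd7a:115c:a1e0::12/tcp/4001 -> rejected (Tailscale IPv6 /48)
//   /ip4/192.168.1.7/tcp/4001        -> accepted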
@@ -77,6 +79,46 @@ enum IncomingDiscoveryMessage { AddDisconnectedCallback(Box>), } +/// Check if a multiaddr is valid for connection +fn is_address_valid(addr: &Multiaddr) -> bool { + use libp2p::multiaddr::Protocol; + + for component in addr.iter() { + match component { + Protocol::Ip4(ip) => { + let ip_addr = IpAddr::V4(ip); + // Filter out loopback and unspecified addresses + if ip_addr.is_loopback() || ip_addr.is_unspecified() { + return false; + } + // Filter out Tailscale ranges (100.64.0.0/10) + if let IpAddr::V4(ipv4) = ip_addr { + let octets = ipv4.octets(); + if octets[0] == 100 && octets[1] >= 64 && octets[1] <= 127 { + return false; + } + } + } + Protocol::Ip6(ip) => { + let ip_addr = IpAddr::V6(ip); + // Filter out loopback and unspecified addresses + if ip_addr.is_loopback() || ip_addr.is_unspecified() { + return false; + } + // Filter out Tailscale IPv6 (fd7a:115c:a1e0::/48) + if let IpAddr::V6(ipv6) = ip_addr { + let segments = ipv6.segments(); + if segments[0] == 0xfd7a && segments[1] == 0x115c && segments[2] == 0xa1e0 { + return false; + } + } + } + _ => {} + } + } + true +} + #[allow(clippy::enum_glob_use)] async fn discovery_task( mut receiver: mpsc::Receiver, @@ -93,9 +135,60 @@ async fn discovery_task( // create callbacks list let mut connected_callbacks: Vec>> = vec![]; let mut disconnected_callbacks: Vec>> = vec![]; + + // Create periodic health check timer with adaptive interval + let fast_check_duration = Duration::from_secs(5); + let slow_check_duration = Duration::from_secs(30); + let mut health_check_interval = interval(fast_check_duration); + let mut no_connection_count = 0; loop { tokio::select! { + _ = health_check_interval.tick() => { + // Check connection health periodically + let connected_peers = swarm.connected_peers().count(); + if connected_peers == 0 { + no_connection_count += 1; + log::info!("RUST: No connected peers (check #{no_connection_count})"); + + // Keep fast checking when disconnected + if health_check_interval.period() != fast_check_duration { + health_check_interval = interval(fast_check_duration); + log::info!("RUST: Switching to fast health checks (every {:?})", fast_check_duration); + } + + // Force mDNS restart after multiple failed checks + if no_connection_count > 1 { // Trigger faster, after 2 checks + log::info!("RUST: Attempting to restart mDNS discovery"); + // Note: In rust-libp2p, we can't easily restart mDNS like in Go, + // but we can force a re-announce by changing listening addresses + // This is a workaround to trigger mDNS to re-announce + + // Try listening on a new ephemeral port to force re-announcement + match swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse().unwrap()) { + Ok(_) => log::info!("RUST: Added new listener to force mDNS re-announcement"), + Err(e) => log::error!("RUST: Failed to add new listener: {e:?}"), + } + + // Also try IPv6 + match swarm.listen_on("/ip6/::/tcp/0".parse().unwrap()) { + Ok(_) => log::info!("RUST: Added IPv6 listener to force mDNS re-announcement"), + Err(e) => log::error!("RUST: Failed to add IPv6 listener: {e:?}"), + } + } + } else { + if no_connection_count > 0 { + log::info!("RUST: Connection restored, currently connected to {connected_peers} peers"); + } + no_connection_count = 0; + + // Switch to slow checking when connected + if health_check_interval.period() != slow_check_duration { + health_check_interval = interval(slow_check_duration); + log::info!("RUST: Switching to slow health checks (every {:?})", slow_check_duration); + } + } + } message = receiver.recv() => { // handle 
closed channel let Some(message) = message else { @@ -120,6 +213,13 @@ async fn discovery_task( Behaviour(Mdns(Discovered(list))) => { for (peer_id, multiaddr) in list { log::info!("RUST: mDNS discovered a new peer: {peer_id} on {multiaddr}"); + + // Filter out invalid addresses + if !is_address_valid(&multiaddr) { + log::info!("RUST: Filtered out invalid address: {multiaddr}"); + continue; + } + let local_peer_id = *swarm.local_peer_id(); // To avoid simultaneous dial races, only the lexicographically larger peer_id dials. if peer_id > local_peer_id { @@ -234,12 +334,36 @@ async fn discovery_task( send_back_addr: send_back_addr.clone(), }); } + + // If this was the last connection to the peer, try to force mDNS re-discovery + if num_established == 0 { + log::info!("RUST: Last connection to peer {peer_id} closed, triggering mDNS re-discovery"); + // Remove from gossipsub to ensure clean state + swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id); + + // The periodic health check above performs the actual re-announcement; this delayed task only logs a marker after the disconnect + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + log::info!("RUST: Delayed mDNS trigger after disconnect"); + }); + } } NewListenAddr { address, .. } => { log::info!("RUST: Local node is listening on {address}"); let local_peer = swarm.local_peer_id(); log::info!("RUST: Local peer_id: {local_peer}"); } + OutgoingConnectionError { peer_id, error, .. } => { + log::error!("RUST: Outgoing connection error to peer {peer_id:?}: {error:?}"); + // Connection failed, might be due to network change + if let Some(peer) = peer_id { + // Remove from gossipsub to allow fresh connection attempts + swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer); + } + } + IncomingConnectionError { send_back_addr, error, .. } => { + log::error!("RUST: Incoming connection error from {send_back_addr}: {error:?}"); + } e => { log::debug!("RUST: Other event {e:?}"); } diff --git a/shared/apply/apply.py b/shared/apply/apply.py index 18914590..abb0b05b 100644 --- a/shared/apply/apply.py +++ b/shared/apply/apply.py @@ -1,14 +1,13 @@ +from __future__ import annotations + import copy from functools import singledispatch -from typing import Mapping, TypeVar +from typing import Mapping -# from shared.topology import Topology from shared.types.common import NodeId from shared.types.events import ( - ChunkGenerated, Event, EventFromEventLog, - Heartbeat, InstanceActivated, InstanceCreated, InstanceDeactivated, @@ -35,20 +34,25 @@ from shared.types.worker.common import NodeStatus, RunnerId from shared.types.worker.instances import Instance, InstanceId, InstanceStatus from shared.types.worker.runners import RunnerStatus -S = TypeVar("S", bound=State) @singledispatch def event_apply(event: Event, state: State) -> State: + """Apply an event to *state*. + + Events decorated with ``@no_op_event`` set ``__no_apply__ = True`` on the + class. Such events are considered *no-ops* and therefore leave the state + unchanged without requiring a dedicated handler in this dispatch table.
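+
+    A minimal doctest-style sketch of the no-op path (assumes ``Heartbeat``
+    is imported and ``node_id`` is an existing ``NodeId``):
+
+        >>> state = State()
+        >>> event_apply(Heartbeat(node_id=node_id), state) is state
+        True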
+ """ + + if getattr(event, "__no_apply__", False): + return state + raise RuntimeError(f"no handler registered for event type {type(event).__name__}") def apply(state: State, event: EventFromEventLog[Event]) -> State: new_state: State = event_apply(event.event, state) return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log}) -@event_apply.register(Heartbeat) -def apply_heartbeat(event: Heartbeat, state: State) -> State: - return state - @event_apply.register(TaskCreated) def apply_task_created(event: TaskCreated, state: State) -> State: new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task} @@ -148,10 +152,6 @@ def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> Sta new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state} return state.model_copy(update={"node_status": new_node_status}) -@event_apply.register(ChunkGenerated) -def apply_chunk_generated(event: ChunkGenerated, state: State) -> State: - return state - @event_apply.register(TopologyNodeCreated) def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State: topology = copy.copy(state.topology) @@ -164,6 +164,13 @@ def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> Sta def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State: topology = copy.copy(state.topology) topology.add_connection(event.edge) + opposite_edge = Connection( + local_node_id=event.edge.send_back_node_id, + send_back_node_id=event.edge.local_node_id, + local_multiaddr=event.edge.send_back_multiaddr, + send_back_multiaddr=event.edge.local_multiaddr + ) + topology.add_connection(opposite_edge) return state.model_copy(update={"topology": topology}) @event_apply.register(TopologyEdgeReplacedAtomically) diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py index d03dbd61..df328367 100644 --- a/shared/db/sqlite/connector.py +++ b/shared/db/sqlite/connector.py @@ -1,6 +1,7 @@ import asyncio import contextlib import json +import random from asyncio import Queue, Task from collections.abc import Sequence from logging import Logger, getLogger @@ -8,8 +9,8 @@ from pathlib import Path from typing import Any, cast from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlmodel import SQLModel +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, create_async_engine from shared.types.events import Event, EventParser, NodeId from shared.types.events._events import Heartbeat @@ -81,7 +82,8 @@ class AsyncSQLiteEventStorage: async def get_events_since( self, - last_idx: int + last_idx: int, + ignore_no_op_events: bool = False ) -> Sequence[EventFromEventLog[Event]]: """Retrieve events after a specific index.""" if self._closed: @@ -107,8 +109,11 @@ class AsyncSQLiteEventStorage: event_data: dict[str, Any] = cast(dict[str, Any], json.loads(raw_event_data)) else: event_data = cast(dict[str, Any], raw_event_data) + event = EventParser.validate_python(event_data) + if ignore_no_op_events and event.__no_apply__: + continue events.append(EventFromEventLog( - event=EventParser.validate_python(event_data), + event=event, origin=NodeId(origin), idx_in_log=rowid # rowid becomes idx_in_log )) @@ -169,17 +174,65 @@ class AsyncSQLiteEventStorage: echo=False, connect_args={ "check_same_thread": False, - } + "timeout": 30.0, # Connection timeout in seconds + }, + pool_pre_ping=True, # 
Test connections before using them + pool_size=5, + max_overflow=10 ) - # Create tables using SQLModel + # Create tables with proper race condition handling async with self._engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) + # First check if the table exists using SQLite's master table + result = await conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name='events'") + ) + table_exists = result.fetchone() is not None - # Enable WAL mode and other optimizations - await conn.execute(text("PRAGMA journal_mode=WAL")) - await conn.execute(text("PRAGMA synchronous=NORMAL")) - await conn.execute(text("PRAGMA cache_size=10000")) + if not table_exists: + try: + # Use CREATE TABLE IF NOT EXISTS as a more atomic operation + # This avoids race conditions between check and create + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS events ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + origin TEXT NOT NULL, + event_type TEXT NOT NULL, + event_id TEXT NOT NULL, + event_data TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """)) + + # Create indexes if they don't exist + await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_origin ON events(origin)")) + await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_event_type ON events(event_type)")) + await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_event_id ON events(event_id)")) + await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_created_at ON events(created_at)")) + await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_origin_created ON events(origin, created_at)")) + + self._logger.info("Events table and indexes created successfully") + except OperationalError as e: + # Even with IF NOT EXISTS, log any unexpected errors + self._logger.error(f"Error creating table: {e}") + # Re-check if table exists now + result = await conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name='events'") + ) + if result.fetchone() is None: + raise RuntimeError(f"Failed to create events table: {e}") from e + else: + self._logger.info("Events table exists (likely created by another process)") + else: + self._logger.debug("Events table already exists") + + # Enable WAL mode and other optimizations with retry logic + await self._execute_pragma_with_retry(conn, [ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA cache_size=10000", + "PRAGMA busy_timeout=30000" # 30 seconds busy timeout + ]) async def _batch_writer(self) -> None: """Background task that drains the queue and commits batches. 
@@ -250,6 +303,69 @@ class AsyncSQLiteEventStorage: if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0: self._logger.debug(f"Committed batch of {len(batch)} events") + except OperationalError as e: + if "database is locked" in str(e): + self._logger.warning(f"Database locked during batch commit, will retry: {e}") + # Retry with exponential backoff + await self._commit_batch_with_retry(batch) + else: + self._logger.error(f"Failed to commit batch: {e}") + raise except Exception as e: self._logger.error(f"Failed to commit batch: {e}") raise + + async def _execute_pragma_with_retry(self, conn: AsyncConnection, pragmas: list[str], max_retries: int = 5) -> None: + """Execute PRAGMA statements with retry logic for database lock errors.""" + for pragma in pragmas: + retry_count = 0 + base_delay: float = 0.1 # 100ms + + while retry_count < max_retries: + try: + await conn.execute(text(pragma)) + break + except OperationalError as e: + if "database is locked" in str(e) and retry_count < max_retries - 1: + delay = cast(float, base_delay * (2 ** retry_count) + random.uniform(0, 0.1)) + self._logger.warning(f"Database locked on '{pragma}', retry {retry_count + 1}/{max_retries} after {delay:.2f}s") + await asyncio.sleep(delay) + retry_count += 1 + else: + self._logger.error(f"Failed to execute '{pragma}' after {retry_count + 1} attempts: {e}") + raise + + async def _commit_batch_with_retry(self, batch: list[tuple[Event, NodeId]], max_retries: int = 5) -> None: + """Commit a batch with retry logic for database lock errors.""" + retry_count = 0 + base_delay: float = 0.1 # 100ms + + while retry_count < max_retries: + try: + assert self._engine is not None + + async with AsyncSession(self._engine) as session: + for event, origin in batch: + stored_event = StoredEvent( + origin=origin, + event_type=event.event_type, + event_id=str(event.event_id), + event_data=event.model_dump(mode='json') + ) + session.add(stored_event) + + await session.commit() + + if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0: + self._logger.debug(f"Committed batch of {len(batch)} events after {retry_count} retries") + return + + except OperationalError as e: + if "database is locked" in str(e) and retry_count < max_retries - 1: + delay = cast(float, base_delay * (2 ** retry_count) + random.uniform(0, 0.1)) + self._logger.warning(f"Database locked on batch commit, retry {retry_count + 1}/{max_retries} after {delay:.2f}s") + await asyncio.sleep(delay) + retry_count += 1 + else: + self._logger.error(f"Failed to commit batch after {retry_count + 1} attempts: {e}") + raise diff --git a/shared/db/sqlite/event_log_manager.py b/shared/db/sqlite/event_log_manager.py index 266b24ff..a35b0d24 100644 --- a/shared/db/sqlite/event_log_manager.py +++ b/shared/db/sqlite/event_log_manager.py @@ -1,5 +1,8 @@ +import asyncio from logging import Logger -from typing import Dict +from typing import Dict, Optional, cast + +from sqlalchemy.exc import OperationalError from shared.constants import EXO_HOME from shared.db.sqlite.config import EventLogConfig, EventLogType @@ -25,11 +28,34 @@ class EventLogManager: EXO_HOME.mkdir(parents=True, exist_ok=True) # TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor. 
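The locked-database retry paths above and below share one backoff shape; a minimal sketch of the delay schedule (the helper name is illustrative, with the connector's base of 0.1 s and up to 100 ms of jitter):

    import random

    def lock_backoff_delays(max_retries: int = 5, base_delay: float = 0.1) -> list[float]:
        # Exponential backoff with jitter: ~0.1s, 0.2s, 0.4s, 0.8s, 1.6s (+0-100ms each).
        return [base_delay * (2 ** n) + random.uniform(0, 0.1) for n in range(max_retries)]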
- async def initialize(self) -> None: - """Initialize both connectors - call this during startup""" + async def initialize(self, max_retries: int = 3) -> None: + """Initialize both connectors with retry logic - call this during startup""" # Both master and worker need both connectors - await self.get_connector(EventLogType.WORKER_EVENTS) - await self.get_connector(EventLogType.GLOBAL_EVENTS) + for log_type in [EventLogType.WORKER_EVENTS, EventLogType.GLOBAL_EVENTS]: + retry_count: int = 0 + last_error: Optional[Exception] = None + + while retry_count < max_retries: + try: + await self.get_connector(log_type) + break + except OperationalError as e: + last_error = e + if "database is locked" in str(e) and retry_count < max_retries - 1: + retry_count += 1 + delay = cast(float, 0.5 * (2 ** retry_count)) + self._logger.warning(f"Database locked while initializing {log_type.value}, retry {retry_count}/{max_retries} after {delay}s") + await asyncio.sleep(delay) + else: + self._logger.error(f"Failed to initialize {log_type.value} after {retry_count + 1} attempts: {e}") + raise RuntimeError(f"Could not initialize {log_type.value} database after {retry_count + 1} attempts") from e + except Exception as e: + self._logger.error(f"Unexpected error initializing {log_type.value}: {e}") + raise + + if retry_count >= max_retries and last_error: + raise RuntimeError(f"Could not initialize {log_type.value} database after {max_retries} attempts") from last_error + self._logger.info("Initialized all event log connectors") async def get_connector(self, log_type: EventLogType) -> AsyncSQLiteEventStorage: @@ -37,20 +63,24 @@ class EventLogManager: if log_type not in self._connectors: db_path = self._config.get_db_path(log_type) - connector = AsyncSQLiteEventStorage( - db_path=db_path, - batch_size=self._config.batch_size, - batch_timeout_ms=self._config.batch_timeout_ms, - debounce_ms=self._config.debounce_ms, - max_age_ms=self._config.max_age_ms, - logger=self._logger - ) - - # Start the connector (creates tables if needed) - await connector.start() - - self._connectors[log_type] = connector - self._logger.info(f"Initialized {log_type.value} connector at {db_path}") + try: + connector = AsyncSQLiteEventStorage( + db_path=db_path, + batch_size=self._config.batch_size, + batch_timeout_ms=self._config.batch_timeout_ms, + debounce_ms=self._config.debounce_ms, + max_age_ms=self._config.max_age_ms, + logger=self._logger + ) + + # Start the connector (creates tables if needed) + await connector.start() + + self._connectors[log_type] = connector + self._logger.info(f"Initialized {log_type.value} connector at {db_path}") + except Exception as e: + self._logger.error(f"Failed to create {log_type.value} connector: {e}") + raise return self._connectors[log_type] diff --git a/shared/topology.py b/shared/topology.py index 9658d483..e8b47520 100644 --- a/shared/topology.py +++ b/shared/topology.py @@ -86,8 +86,11 @@ class Topology(TopologyProto): yield connection def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None: - rx_idx = self._node_id_to_rx_id_map[node_id] - return self._graph.get_node_data(rx_idx).node_profile + try: + rx_idx = self._node_id_to_rx_id_map[node_id] + return self._graph.get_node_data(rx_idx).node_profile + except KeyError: + return None def get_node_multiaddr(self, node_id: NodeId) -> Multiaddr: for connection in self.list_connections(): @@ -106,8 +109,11 @@ class Topology(TopologyProto): self._graph.update_edge_by_index(rx_idx, connection) def get_connection_profile(self, 
connection: Connection) -> ConnectionProfile | None: - rx_idx = self._edge_id_to_rx_id_map[connection] - return self._graph.get_edge_data_by_index(rx_idx).connection_profile + try: + rx_idx = self._edge_id_to_rx_id_map[connection] + return self._graph.get_edge_data_by_index(rx_idx).connection_profile + except KeyError: + return None def remove_node(self, node_id: NodeId) -> None: rx_idx = self._node_id_to_rx_id_map[node_id] @@ -118,27 +124,22 @@ class Topology(TopologyProto): def remove_connection(self, connection: Connection) -> None: rx_idx = self._edge_id_to_rx_id_map[connection] - print(f"removing connection: {connection}, is bridge: {self._is_bridge(connection)}") if self._is_bridge(connection): # Determine the reference node from which reachability is calculated. # Prefer a master node if the topology knows one; otherwise fall back to # the local end of the connection being removed. reference_node_id: NodeId = self.master_node_id if self.master_node_id is not None else connection.local_node_id orphan_node_ids = self._get_orphan_node_ids(reference_node_id, connection) - print(f"orphan node ids: {orphan_node_ids}") for orphan_node_id in orphan_node_ids: orphan_node_rx_id = self._node_id_to_rx_id_map[orphan_node_id] - print(f"removing orphan node: {orphan_node_id}, rx_id: {orphan_node_rx_id}") self._graph.remove_node(orphan_node_rx_id) del self._node_id_to_rx_id_map[orphan_node_id] + del self._rx_id_to_node_id_map[orphan_node_rx_id] self._graph.remove_edge_from_index(rx_idx) del self._edge_id_to_rx_id_map[connection] if rx_idx in self._rx_id_to_node_id_map: del self._rx_id_to_node_id_map[rx_idx] - - - print(f"topology after edge removal: {self.to_snapshot()}") def get_cycles(self) -> list[list[Node]]: cycle_idxs = rx.simple_cycles(self._graph) @@ -161,14 +162,12 @@ class Topology(TopologyProto): return topology def _is_bridge(self, connection: Connection) -> bool: - edge_idx = self._edge_id_to_rx_id_map[connection] - graph_copy: rx.PyDiGraph[Node, Connection] = self._graph.copy() - components_before = rx.strongly_connected_components(graph_copy) - - graph_copy.remove_edge_from_index(edge_idx) - components_after = rx.strongly_connected_components(graph_copy) - - return components_after > components_before + """Check if removing this connection will orphan any nodes from the master.""" + if self.master_node_id is None: + return False + + orphan_node_ids = self._get_orphan_node_ids(self.master_node_id, connection) + return len(orphan_node_ids) > 0 def _get_orphan_node_ids(self, master_node_id: NodeId, connection: Connection) -> list[NodeId]: """Return node_ids that become unreachable from `master_node_id` once `connection` is removed. diff --git a/shared/types/events/_events.py b/shared/types/events/_events.py index cb092909..b74d185a 100644 --- a/shared/types/events/_events.py +++ b/shared/types/events/_events.py @@ -3,7 +3,9 @@ from enum import Enum from typing import ( TYPE_CHECKING, Annotated, + Any, Literal, + TypeVar, Union, get_args, get_origin, @@ -90,6 +92,7 @@ class _BaseEvent[T: _EventType](BaseModel): event_type: T event_id: EventId = EventId() + __no_apply__: bool = False def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: """Check if the event was sent by the correct node. @@ -99,6 +102,20 @@ class _BaseEvent[T: _EventType](BaseModel): """ return True +_E = TypeVar("_E", bound=_BaseEvent[Any]) + +def no_op_event(cls: type[_E]) -> type[_E]: + """Decorator to mark an event class as a *no-op*. 
+ + Events marked as no-ops do not require an `event_apply` registration – the + apply layer will simply return the current state unchanged. This reduces + boilerplate and keeps console output quieter for high-frequency events + such as *Heartbeat* or streaming *ChunkGenerated* messages. + """ + + cls.__no_apply__ = True # Used by the apply layer to identify no-op events + return cls +@no_op_event class Heartbeat(_BaseEvent[_EventType.Heartbeat]): event_type: Literal[_EventType.Heartbeat] = _EventType.Heartbeat node_id: NodeId @@ -152,6 +169,7 @@ class InstanceReplacedAtomically(_BaseEvent[_EventType.InstanceReplacedAtomicall instance_to_replace: InstanceId new_instance_id: InstanceId +# TODO: RunnerCreated class RunnerStatusUpdated(_BaseEvent[_EventType.RunnerStatusUpdated]): event_type: Literal[_EventType.RunnerStatusUpdated] = _EventType.RunnerStatusUpdated @@ -176,6 +194,7 @@ class WorkerStatusUpdated(_BaseEvent[_EventType.WorkerStatusUpdated]): node_state: NodeStatus +@no_op_event class ChunkGenerated(_BaseEvent[_EventType.ChunkGenerated]): event_type: Literal[_EventType.ChunkGenerated] = _EventType.ChunkGenerated command_id: CommandId diff --git a/shared/types/worker/common.py b/shared/types/worker/common.py index c3b9aeea..754b0af4 100644 --- a/shared/types/worker/common.py +++ b/shared/types/worker/common.py @@ -14,4 +14,3 @@ class RunnerId(ID): class NodeStatus(str, Enum): Idle = "Idle" Running = "Running" - Paused = "Paused" diff --git a/shared/types/worker/ops.py b/shared/types/worker/ops.py index 82db7c77..0987f3c7 100644 --- a/shared/types/worker/ops.py +++ b/shared/types/worker/ops.py @@ -16,7 +16,6 @@ class RunnerOpType(str, Enum): RUNNER_UP = "runner_up" RUNNER_DOWN = "runner_down" RUNNER_FAILED = "runner_failed" - DOWNLOAD = "download" CHAT_COMPLETION = "chat_completion" RunnerOpT = TypeVar("RunnerOpT", bound=RunnerOpType) @@ -47,13 +46,6 @@ class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]): op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(default=RunnerOpType.RUNNER_FAILED, frozen=True) runner_id: RunnerId -class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]): - op_type: Literal[RunnerOpType.DOWNLOAD] = Field(default=RunnerOpType.DOWNLOAD, frozen=True) - instance_id: InstanceId - runner_id: RunnerId - shard_metadata: ShardMetadata - hosts: list[Host] - class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]): op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(default=RunnerOpType.CHAT_COMPLETION, frozen=True) runner_id: RunnerId @@ -68,7 +60,6 @@ RunnerOp = Annotated[ RunnerUpOp, RunnerDownOp, RunnerFailedOp, - DownloadOp, ExecuteTaskOp, ], Field(discriminator="op_type") diff --git a/shared/types/worker/runners.py b/shared/types/worker/runners.py index 51a08958..c1428f7e 100644 --- a/shared/types/worker/runners.py +++ b/shared/types/worker/runners.py @@ -12,9 +12,8 @@ from shared.types.worker.shards import ShardMetadata class RunnerStatusType(str, Enum): - Assigned = "Assigned" Downloading = "Downloading" - Ready = "Ready" + Inactive = "Inactive" Starting = "Starting" Loaded = "Loaded" Running = "Running" @@ -28,41 +27,30 @@ class BaseRunnerStatus(BaseModel, Generic[RunnerStatusTypeT]): runner_status: RunnerStatusTypeT -# Emitted by the Master -class AssignedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Assigned]): - runner_status: Literal[RunnerStatusType.Assigned] = Field(default=RunnerStatusType.Assigned) - -# Emitted by the Worker class DownloadingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Downloading]): 
runner_status: Literal[RunnerStatusType.Downloading] = Field(default=RunnerStatusType.Downloading) download_progress: DownloadProgress -# Emitted by the Worker -class ReadyRunnerStatus(BaseRunnerStatus[RunnerStatusType.Ready]): - runner_status: Literal[RunnerStatusType.Ready] = Field(default=RunnerStatusType.Ready) +class InactiveRunnerStatus(BaseRunnerStatus[RunnerStatusType.Inactive]): + runner_status: Literal[RunnerStatusType.Inactive] = Field(default=RunnerStatusType.Inactive) -# Emitted by the Master class StartingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Starting]): runner_status: Literal[RunnerStatusType.Starting] = Field(default=RunnerStatusType.Starting) -# Emitted by the Worker class LoadedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Loaded]): runner_status: Literal[RunnerStatusType.Loaded] = Field(default=RunnerStatusType.Loaded) -# Emitted by the Worker class RunningRunnerStatus(BaseRunnerStatus[RunnerStatusType.Running]): runner_status: Literal[RunnerStatusType.Running] = Field(default=RunnerStatusType.Running) -# Emitted by the Worker class FailedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Failed]): runner_status: Literal[RunnerStatusType.Failed] = Field(default=RunnerStatusType.Failed) error_message: str | None = None RunnerStatus = Annotated[ - AssignedRunnerStatus - | DownloadingRunnerStatus - | ReadyRunnerStatus + DownloadingRunnerStatus + | InactiveRunnerStatus | StartingRunnerStatus | LoadedRunnerStatus | RunningRunnerStatus diff --git a/worker/common.py b/worker/common.py new file mode 100644 index 00000000..ffbe07db --- /dev/null +++ b/worker/common.py @@ -0,0 +1,35 @@ +from copy import deepcopy +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from shared.types.common import Host +from shared.types.events import ( + InstanceId, + RunnerStatusUpdated, +) +from shared.types.worker.common import RunnerId +from shared.types.worker.runners import ( + RunnerStatus, +) +from shared.types.worker.shards import ShardMetadata +from worker.runner.runner_supervisor import RunnerSupervisor + + +class AssignedRunner(BaseModel): + runner_id: RunnerId + instance_id: InstanceId + shard_metadata: ShardMetadata # just data + hosts: list[Host] + + status: RunnerStatus + failures: list[tuple[float, Exception]] = [] + runner: Optional[RunnerSupervisor] # set if the runner is 'up' + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def status_update_event(self) -> RunnerStatusUpdated: + return RunnerStatusUpdated( + runner_id=self.runner_id, + runner_status=deepcopy(self.status), + ) diff --git a/worker/download/conftest.py b/worker/download/conftest.py index 3c821c98..9f60b97a 100644 --- a/worker/download/conftest.py +++ b/worker/download/conftest.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pytest from shared.models.model_meta import get_model_meta @@ -13,7 +11,7 @@ async def model_meta() -> ModelMetadata: @pytest.fixture -def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path): +def pipeline_shard_meta(model_meta: ModelMetadata): def _pipeline_shard_meta( num_nodes: int = 1, device_rank: int = 0 ) -> PipelineShardMetadata: diff --git a/worker/main.py b/worker/main.py index 0fd25765..01e4d562 100644 --- a/worker/main.py +++ b/worker/main.py @@ -1,658 +1,52 @@ import asyncio import logging -import time -from asyncio import Queue -from copy import deepcopy -from functools import partial -from time import process_time -from typing import AsyncGenerator, Optional - -from pydantic import BaseModel, ConfigDict from 
shared.apply import apply -from shared.db.sqlite import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from shared.types.common import Host, NodeId +from shared.types.common import NodeId from shared.types.events import ( - ChunkGenerated, - Event, - InstanceDeleted, - InstanceId, - NodePerformanceMeasured, - RunnerDeleted, - RunnerStatusUpdated, - TaskFailed, - TaskStateUpdated, + NodePerformanceMeasured, ) from shared.types.profiling import NodePerformanceProfile -from shared.types.state import State -from shared.types.tasks import TaskId, TaskStatus -from shared.types.worker.common import RunnerId -from shared.types.worker.downloads import ( - DownloadCompleted, - DownloadFailed, - DownloadOngoing, - DownloadProgressData, -) -from shared.types.worker.instances import InstanceStatus from shared.types.worker.ops import ( - AssignRunnerOp, - DownloadOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerFailedOp, - RunnerOp, - RunnerOpType, - RunnerUpOp, - UnassignRunnerOp, + RunnerOp, ) -from shared.types.worker.runners import ( - AssignedRunnerStatus, - DownloadingRunnerStatus, - FailedRunnerStatus, - LoadedRunnerStatus, - ReadyRunnerStatus, - RunnerStatus, - RunnerStatusType, - RunningRunnerStatus, -) -from shared.types.worker.shards import ShardMetadata from shared.utils import get_node_id_keypair from worker.download.impl_shard_downloader import exo_shard_downloader -from worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader -from worker.runner.runner_supervisor import RunnerSupervisor +from worker.plan import plan from worker.utils.profile import start_polling_node_metrics +from worker.worker import Worker -class AssignedRunner(BaseModel): - runner_id: RunnerId - instance_id: InstanceId - shard_metadata: ShardMetadata # just data - hosts: list[Host] - - status: RunnerStatus - failures: list[tuple[float, Exception]] = [] - runner: Optional[RunnerSupervisor] # set if the runner is 'up' - - model_config = ConfigDict(arbitrary_types_allowed=True) - - is_downloaded: bool = False - - def set_is_downloaded(self, is_downloaded: bool) -> None: - self.is_downloaded = is_downloaded - - def status_update_event(self) -> RunnerStatusUpdated: - return RunnerStatusUpdated( - runner_id=self.runner_id, - runner_status=deepcopy(self.status), - ) - -class Worker: - def __init__( - self, - node_id: NodeId, - logger: logging.Logger, - shard_downloader: ShardDownloader, - worker_events: AsyncSQLiteEventStorage | None, - global_events: AsyncSQLiteEventStorage | None, - ): - self.node_id: NodeId = node_id - self.state: State = State() - self.shard_downloader: ShardDownloader = shard_downloader - self.worker_events: AsyncSQLiteEventStorage | None = worker_events # worker_events is None in some tests. - self.global_events: AsyncSQLiteEventStorage | None = global_events - self.logger: logging.Logger = logger - - self.assigned_runners: dict[RunnerId, AssignedRunner] = {} - self._task: asyncio.Task[None] | None = None - - ## Op Executors - - async def _execute_assign_op( - self, op: AssignRunnerOp - ) -> AsyncGenerator[Event, None]: - ''' - Here, we are sure that the model is already downloaded. - This op moves the runner from Assigned -> Ready state. 
- ''' - self.assigned_runners[op.runner_id] = AssignedRunner( - runner_id=op.runner_id, - instance_id=op.instance_id, - shard_metadata=op.shard_metadata, - hosts=op.hosts, - status=AssignedRunnerStatus(), - runner=None, - ) - - yield self.assigned_runners[op.runner_id].status_update_event() - - async def _execute_unassign_op( - self, op: UnassignRunnerOp - ) -> AsyncGenerator[Event, None]: - if op.runner_id not in self.assigned_runners: - return - - # We can try to do a graceful shutdown of the runner. - runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner - if runner is not None: - await runner.astop() - - # This is all we really need: - del self.assigned_runners[op.runner_id] - yield RunnerDeleted(runner_id=op.runner_id) - - return - yield - - async def _execute_runner_up_op( - self, op: RunnerUpOp, initialize_timeout: Optional[float] = None - ) -> AsyncGenerator[Event, None]: - assigned_runner = self.assigned_runners[op.runner_id] - - # TODO: This should be dynamic, based on the size of the model. - if not initialize_timeout: - gigabytes_per_second = 10 - - shard = assigned_runner.shard_metadata - weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes - - initialize_timeout = weights_size_kb / (1024**2 * gigabytes_per_second) + 2.0 # Add a constant 2.0 to ensure connection can be made as well - - try: - assigned_runner.runner = await asyncio.wait_for( - RunnerSupervisor.create( - model_shard_meta=assigned_runner.shard_metadata, - hosts=assigned_runner.hosts, - logger=self.logger, - ), - timeout=initialize_timeout, - ) - except TimeoutError as e: - import traceback - - tb = traceback.format_exc() - e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}") - async for event in self._fail_runner(e=e, runner_id=op.runner_id): - yield event - return - - if assigned_runner.runner.healthy: - assigned_runner.status = LoadedRunnerStatus() - else: - assigned_runner.status = FailedRunnerStatus() - yield self.assigned_runners[op.runner_id].status_update_event() - - async def _execute_runner_down_op( - self, op: RunnerDownOp - ) -> AsyncGenerator[Event, None]: - assigned_runner = self.assigned_runners[op.runner_id] - - if isinstance(assigned_runner.runner, RunnerSupervisor): - await assigned_runner.runner.astop() - - assigned_runner.runner = None - - assigned_runner.status = ReadyRunnerStatus() - yield assigned_runner.status_update_event() - return - - async def _execute_runner_failed_op( - self, op: RunnerFailedOp - ) -> AsyncGenerator[Event, None]: - ''' - We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down. - ''' - assigned_runner = self.assigned_runners[op.runner_id] - - assigned_runner.status = FailedRunnerStatus() - yield self.assigned_runners[op.runner_id].status_update_event() - - async def _execute_download_op( - self, op: DownloadOp - ) -> AsyncGenerator[Event, None]: - ''' - The model needs assigning and then downloading. - This op moves the runner from Assigned -> Downloading -> Ready state. 
- ''' - - initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata) - if initial_progress.status == "complete": - self.assigned_runners[op.runner_id].set_is_downloaded(True) - self.assigned_runners[op.runner_id].status = DownloadingRunnerStatus( - download_progress=DownloadCompleted( - node_id=self.node_id, - ) - ) - yield self.assigned_runners[op.runner_id].status_update_event() - self.assigned_runners[op.runner_id].status = ReadyRunnerStatus() - yield self.assigned_runners[op.runner_id].status_update_event() - return - - initial_status = DownloadingRunnerStatus( - download_progress=DownloadOngoing( - node_id=self.node_id, - download_progress=DownloadProgressData( - total_bytes=initial_progress.total_bytes, - downloaded_bytes=initial_progress.downloaded_bytes - ) - ) - ) - - self.assigned_runners[op.runner_id] = AssignedRunner( - runner_id=op.runner_id, - instance_id=op.instance_id, - shard_metadata=op.shard_metadata, - hosts=op.hosts, - status=initial_status, - runner=None, - ) - assigned_runner: AssignedRunner = self.assigned_runners[op.runner_id] - yield assigned_runner.status_update_event() - - # Download it! - # TODO: we probably want download progress as part of a callback that gets passed to the downloader. - download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue() - def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None: - download_progress_queue.put_nowait(progress) - - - self.shard_downloader.on_progress(download_progress_callback) - - asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata)) - - # TODO: Dynamic timeout, timeout on no packet update received. - timeout_secs = 10 * 60 - start_time = process_time() - last_yield_progress = start_time - while process_time() - start_time < timeout_secs: - progress: RepoDownloadProgress = await download_progress_queue.get() - if progress.status == "complete": - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadCompleted( - node_id=self.node_id, - ) - ) - yield assigned_runner.status_update_event() - assigned_runner.set_is_downloaded(True) - assigned_runner.status = ReadyRunnerStatus() - yield assigned_runner.status_update_event() - break - elif progress.status == "in_progress": - if process_time() - last_yield_progress > 1: - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadOngoing( - node_id=self.node_id, - download_progress=DownloadProgressData( - total_bytes=progress.total_bytes, - downloaded_bytes=progress.downloaded_bytes, - ) - ) - ) - yield assigned_runner.status_update_event() - last_yield_progress = process_time() - else: - assigned_runner.status = DownloadingRunnerStatus( - download_progress=DownloadFailed( - node_id=self.node_id, - error_message=f"Timeout downloading model: {op.shard_metadata.model_meta.model_id}" - ) - ) - yield assigned_runner.status_update_event() - - - async def _execute_task_op( - self, op: ExecuteTaskOp - ) -> AsyncGenerator[Event, None]: - ''' - This is the entry point for a chat completion starting. - While there is only one execute function, it will get called in different ways for runner 0 and runner [1, 2, 3, ...]. - Runners [1, 2, 3, ...] will run this method when a task is in 'pending' state. - Runner 0 will run this method when a task is in 'running' state. - TODO: How do we handle the logic of ensuring that n-1 nodes have started their execution before allowing the 0'th runner to start? 
- This is still a little unclear to me. - ''' - assigned_runner = self.assigned_runners[op.runner_id] - - async def inner_execute(queue: asyncio.Queue[Event]) -> None: - async def running_callback(queue: asyncio.Queue[Event]) -> None: - # Called when the MLX process has been kicked off - assigned_runner.status = RunningRunnerStatus() - await queue.put(assigned_runner.status_update_event()) - - if assigned_runner.shard_metadata.device_rank == 0: - await queue.put(TaskStateUpdated( - task_id=op.task.task_id, - task_status=TaskStatus.RUNNING, - )) - - try: - assert assigned_runner.runner is not None - assert assigned_runner.runner.healthy - - async for chunk in assigned_runner.runner.stream_response( - task=op.task, - request_started_callback=partial(running_callback, queue)): - if assigned_runner.shard_metadata.device_rank == 0: - await queue.put(ChunkGenerated( - # todo: at some point we will no longer have a bijection between task_id and row_id. - # So we probably want to store a mapping between these two in our Worker object. - command_id=chunk.command_id, - chunk=chunk - )) - - if assigned_runner.shard_metadata.device_rank == 0: - await queue.put(TaskStateUpdated( - task_id=op.task.task_id, - task_status=TaskStatus.COMPLETE, - )) - - # After a successful inference: - assigned_runner.status = LoadedRunnerStatus() - await queue.put(assigned_runner.status_update_event()) - - - except Exception as e: - # An exception occurs in the runner supervisor - self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}') - async for event in self._fail_task(e, op.runner_id, op.task.task_id): - await queue.put(event) - - queue: Queue[Event] = asyncio.Queue() - task = asyncio.create_task(inner_execute(queue)) - - # TODO: Initial (prefil) timeout can be dynamic - # model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes - - try: - # Yield items from the queue - # timeout = 30. - timeout = 3. - while True: - item: Event = await asyncio.wait_for(queue.get(), timeout=timeout) - yield item - timeout = 2. - if isinstance(item, RunnerStatusUpdated) and isinstance( - item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus) - ): - if isinstance(item.runner_status, LoadedRunnerStatus): - assigned_runner.failures = [] - - break - except TimeoutError as e: - # Runner supervisor doesn't respond in time; so we put the runner & task into a failed state - self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.') - async for event in self._fail_task(e, op.runner_id, op.task.task_id): - yield event - finally: - # Ensure the task is cleaned up - try: - await asyncio.wait_for(task, timeout=5) - except asyncio.TimeoutError: - self.logger.warning("Timed out waiting for task cleanup after inference execution.") - - - ## Operation Planner - - async def _execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]: - ## It would be great if we can get rid of this async for ... yield pattern. 
- match op.op_type: - case RunnerOpType.ASSIGN_RUNNER: - event_generator = self._execute_assign_op(op) - case RunnerOpType.UNASSIGN_RUNNER: - event_generator = self._execute_unassign_op(op) - case RunnerOpType.RUNNER_UP: - event_generator = self._execute_runner_up_op(op) - case RunnerOpType.RUNNER_DOWN: - event_generator = self._execute_runner_down_op(op) - case RunnerOpType.RUNNER_FAILED: - event_generator = self._execute_runner_failed_op(op) - case RunnerOpType.DOWNLOAD: - event_generator = self._execute_download_op(op) - case RunnerOpType.CHAT_COMPLETION: - event_generator = self._execute_task_op(op) - - async for event in event_generator: - yield event - - ## Planning logic - def plan(self, state: State) -> RunnerOp | None: - # Compare state to worker 'mood' - - # for runner_id, assigned_runner in self.assigned_runners.items(): - # if len(assigned_runner.failures) == 3: - # raise Exception('Too many error occurred in assigned runner - assumed to be recurrent and unrecoverable.\nErrors are as follows: {assigned_runner.failures}') - - # First, unassign assigned runners that are no longer in the state. - for runner_id, _ in self.assigned_runners.items(): - runner_ids: list[RunnerId] = [ - runner_id - for instance in state.instances.values() - for runner_id in instance.shard_assignments.runner_to_shard - ] - if runner_id not in runner_ids: - return UnassignRunnerOp(runner_id=runner_id) - - for runner_id, assigned_runner in self.assigned_runners.items(): - if assigned_runner.runner is not None and \ - not assigned_runner.runner.healthy and \ - not isinstance(assigned_runner.status, FailedRunnerStatus): - return RunnerFailedOp(runner_id=runner_id) - - # Then spin down active runners - for _instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): - if node_id != self.node_id: - continue - - # We spin down a runner if it's meant to be inactive and it's Loaded. - if runner_id in self.assigned_runners and \ - isinstance(self.assigned_runners[runner_id].status, LoadedRunnerStatus) and \ - instance.instance_type == InstanceStatus.INACTIVE: - return RunnerDownOp(runner_id=runner_id) - - # If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down - # TODO: We need to limit number of retries if we keep failing. - for _instance_id, instance in state.instances.items(): - if self.node_id in instance.shard_assignments.node_to_runner and \ - instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \ - not isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, ReadyRunnerStatus): # make sure that our runner has not already been spun down into ready state - other_node_in_instance_has_failed = False - for runner_id in instance.shard_assignments.runner_to_shard: - if runner_id in state.runners and \ - isinstance(state.runners[runner_id], FailedRunnerStatus) and \ - runner_id not in self.assigned_runners: - other_node_in_instance_has_failed= True - - if other_node_in_instance_has_failed: - # Spin down *our* runner - return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) - - # If we are failed - and *all of the other nodes have spun down* - then we can spin down too. 
- for _instance_id, instance in state.instances.items(): - if self.node_id in instance.shard_assignments.node_to_runner and \ - instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \ - instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \ - isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, FailedRunnerStatus): - - num_spundown_nodes = 0 - for runner_id in instance.shard_assignments.runner_to_shard: - if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \ - runner_id not in self.assigned_runners: - num_spundown_nodes += 1 - # Suggested: - # if runner_id in state.runners and isinstance(state.runners[runner_id], ReadyRunnerStatus): - # if runner_id != instance.shard_assignments.node_to_runner[self.node_id]: - # num_spundown_nodes += 1 - - if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1: - # All the other nodes are spun down - so now we can spin down too. - # This also catches the case of 1-node. If there's one node in the instance then we should spin down straight away - return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) - - # Then assign runners we do want - for instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): - if node_id != self.node_id: - continue - - if runner_id not in self.assigned_runners: - return AssignRunnerOp( - runner_id=runner_id, - instance_id=instance_id, - shard_metadata=instance.shard_assignments.runner_to_shard[runner_id], - hosts=instance.hosts - ) - - # Then make sure things are downloading. - for instance_id, instance in state.instances.items(): - # We should already have asserted that this runner exists - # If it didn't exist then we return a assign_runner op. - for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): - if node_id != self.node_id: - continue - assert runner_id in self.assigned_runners - - runner = self.assigned_runners[runner_id] - - if not runner.is_downloaded: - if runner.status.runner_status == RunnerStatusType.Downloading: # Forward compatibility - # TODO: If failed status then we retry - return None - else: - return DownloadOp( - runner_id=runner_id, - instance_id=instance_id, - shard_metadata=instance.shard_assignments.runner_to_shard[runner_id], - hosts=instance.hosts - ) - - # Then spin up 'ready' runners that should be active - for _instance_id, instance in state.instances.items(): - if self.node_id in instance.shard_assignments.node_to_runner and \ - self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].runner is None and \ - instance.instance_type == InstanceStatus.ACTIVE: - - # We are part of this instance, we want it up but it hasn't been spun up yet. - # Need to assert all other runners are ready before we can spin up. - ready_to_spin = True - for runner_id in instance.shard_assignments.node_to_runner.values(): - if runner_id in state.runners and state.runners[runner_id].runner_status != RunnerStatusType.Ready: - ready_to_spin = False - - if ready_to_spin: - return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id]) - - # Then make sure things are running based on tasks. 
- for instance_id, instance in state.instances.items(): - for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): - if node_id != self.node_id: - continue - assert runner_id in self.assigned_runners - runner = self.assigned_runners[runner_id] - if runner.status.runner_status != RunnerStatusType.Loaded: - continue # The only previous state to get to Running is from Loaded - - for _, task in state.tasks.items(): - if task.instance_id == instance_id and ( - task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED - ): - if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1): - return ExecuteTaskOp(runner_id=runner_id, task=task) - else: - # We already know our own status is Loaded. We are rank 0, - # so let's check that all the other runners are running - ready for us to fire the prompt. - running_runner_count = 0 - for other_runner_id, other_runner_status in state.runners.items(): - if other_runner_id in instance.shard_assignments.node_to_runner.values() and \ - isinstance(other_runner_status, RunningRunnerStatus): - running_runner_count += 1 - - if running_runner_count == runner.shard_metadata.world_size - 1: - return ExecuteTaskOp(runner_id=runner_id, task=task) - - return None - - - async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]: - if runner_id in self.assigned_runners: - assigned_runner = self.assigned_runners[runner_id] - - assigned_runner.runner = None - assigned_runner.status = FailedRunnerStatus(error_message=str(e)) - assigned_runner.failures.append( - ( - time.time(), - e - ) - ) - - # Reset failure count back to 0 when succesful - if len(assigned_runner.failures) >= 3: - # Too many retries. We will emit a DeleteInstance - yield InstanceDeleted( - instance_id=assigned_runner.instance_id - ) - - yield assigned_runner.status_update_event() - - - async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]: - if runner_id in self.assigned_runners: - yield TaskStateUpdated( - task_id=task_id, - task_status=TaskStatus.FAILED, - ) - - yield TaskFailed( - task_id=task_id, - error_type=str(type(e)), - error_message=str(e) - ) - - async for event in self._fail_runner(e, runner_id): - yield event - - - async def event_publisher(self, event: Event) -> None: - assert self.worker_events is not None - await self.worker_events.append_events([event], self.node_id) - self.logger.info(f"published event: {event}") - - # Handle state updates - async def run(self): - assert self.global_events is not None +async def run(worker_state: Worker): + assert worker_state.global_events is not None while True: # 1. get latest events - events = await self.global_events.get_events_since(self.state.last_event_applied_idx) + events = await worker_state.global_events.get_events_since(worker_state.state.last_event_applied_idx) # 2. for each event, apply it to the state and run sagas for event_from_log in events: - self.state = apply(self.state, event_from_log) + worker_state.state = apply(worker_state.state, event_from_log) # 3. based on the updated state, we plan & execute an operation. - op: RunnerOp | None = self.plan(self.state) + op: RunnerOp | None = plan( + worker_state.assigned_runners, + worker_state.node_id, + worker_state.state.instances, + worker_state.state.runners, + worker_state.state.tasks, + ) if op is not None: - self.logger.info(f"!!! plan result: {op}") + worker_state.logger.info(f"!!! 
plan result: {op}") # run the op, synchronously blocking for now if op is not None: - try: - async for event in self._execute_op(op): - await self.event_publisher(event) - except Exception as e: - # execeute_task_op already has its own exception handling here. So we assume we had an exception in one of the other op types. - # we therefore just fail the runner. - self.logger.warning(f"Encountered exception when executing worker op {op}: {e}. \n Runner will be spun down and retried.") - async for event in self._fail_runner( - e, - runner_id=op.runner_id, - ): - await self.event_publisher(event) + async for event in worker_state.execute_op(op): + await worker_state.event_publisher(event) await asyncio.sleep(0.01) - if len(events) > 0: - self.logger.info(f"state: {self.state}") + async def main(): @@ -678,7 +72,7 @@ async def main(): worker = Worker(node_id, logger, shard_downloader, event_log_manager.worker_events, event_log_manager.global_events) - await worker.run() + await run(worker) if __name__ == "__main__": asyncio.run(main()) diff --git a/worker/plan.py b/worker/plan.py new file mode 100644 index 00000000..4d644023 --- /dev/null +++ b/worker/plan.py @@ -0,0 +1,205 @@ +from typing import Mapping + +from shared.types.common import NodeId +from shared.types.events import ( + InstanceId, +) +from shared.types.tasks import Task, TaskId, TaskStatus +from shared.types.worker.common import RunnerId +from shared.types.worker.instances import Instance, InstanceStatus +from shared.types.worker.ops import ( + AssignRunnerOp, + ExecuteTaskOp, + RunnerDownOp, + RunnerFailedOp, + RunnerOp, + RunnerUpOp, + UnassignRunnerOp, +) +from shared.types.worker.runners import ( + DownloadingRunnerStatus, + FailedRunnerStatus, + InactiveRunnerStatus, + LoadedRunnerStatus, + RunnerStatus, + RunnerStatusType, + RunningRunnerStatus, +) +from worker.common import AssignedRunner + + +def unassign_runners(instances: Mapping[InstanceId, Instance], state_runners: Mapping[RunnerId, RunnerStatus], assigned_runners: dict[RunnerId, AssignedRunner]) -> UnassignRunnerOp | None: + runner_ids: set[RunnerId] = { + runner_id + for instance in instances.values() + for runner_id in instance.shard_assignments.runner_to_shard + } + for runner_id in assigned_runners: + if runner_id not in runner_ids: + return UnassignRunnerOp(runner_id=runner_id) + + # If the global state still reports our runner as Downloading, the local assignment is stale: downloading happens inside a blocking AssignRunnerOp, so a runner we already hold should never still be marked as Downloading.
+ for assigned_runner_id in assigned_runners: + if assigned_runner_id in state_runners and \ + isinstance(state_runners[assigned_runner_id], DownloadingRunnerStatus): + return UnassignRunnerOp(runner_id=assigned_runner_id) + + return None + +def failed_runners(assigned_runners: dict[RunnerId, AssignedRunner]) -> RunnerFailedOp | None: + for runner_id, assigned_runner in assigned_runners.items(): + if assigned_runner.runner is not None and \ + not assigned_runner.runner.healthy and \ + not isinstance(assigned_runner.status, FailedRunnerStatus): + return RunnerFailedOp(runner_id=runner_id) + return None + +def spin_down_runners( + instances: Mapping[InstanceId, Instance], + assigned_runners: dict[RunnerId, AssignedRunner], + state_runners: Mapping[RunnerId, RunnerStatus], + worker_node_id: NodeId) -> RunnerDownOp | None: + for _instance_id, instance in instances.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): + if node_id != worker_node_id: + continue + + # We spin down a runner if it's meant to be inactive and it's Loaded. + if runner_id in assigned_runners and \ + isinstance(assigned_runners[runner_id].status, LoadedRunnerStatus) and \ + instance.instance_type == InstanceStatus.INACTIVE: + return RunnerDownOp(runner_id=runner_id) + + # If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down + for _instance_id, instance in instances.items(): + if worker_node_id in instance.shard_assignments.node_to_runner and \ + instance.shard_assignments.node_to_runner[worker_node_id] in assigned_runners and \ + not isinstance(assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].status, InactiveRunnerStatus): # make sure that our runner has not already been spun down into the Inactive state + other_node_in_instance_has_failed = False + for runner_id in instance.shard_assignments.runner_to_shard: + if runner_id in state_runners and \ + isinstance(state_runners[runner_id], FailedRunnerStatus) and \ + runner_id not in assigned_runners: + other_node_in_instance_has_failed = True + + if other_node_in_instance_has_failed: + # Spin down *our* runner + return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id]) + + # If we are failed - and *all of the other nodes have spun down* - then we can spin down too. + for _instance_id, instance in instances.items(): + if worker_node_id in instance.shard_assignments.node_to_runner and \ + instance.shard_assignments.node_to_runner[worker_node_id] in state_runners and \ + instance.shard_assignments.node_to_runner[worker_node_id] in assigned_runners and \ + isinstance(assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].status, FailedRunnerStatus): + + num_spundown_nodes = 0 + for runner_id in instance.shard_assignments.runner_to_shard: + if isinstance(state_runners[runner_id], InactiveRunnerStatus) and \ + runner_id not in assigned_runners: + num_spundown_nodes += 1 + # Suggested: + # if runner_id in state_runners and isinstance(state_runners[runner_id], InactiveRunnerStatus): + # if runner_id != instance.shard_assignments.node_to_runner[worker_node_id]: + # num_spundown_nodes += 1 + + if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1: + # All the other nodes are spun down - so now we can spin down too. + # This also catches the case of 1-node.
If there's one node in the instance then we should spin down straight away + return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id]) + return None + +def assign_runners(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], worker_node_id: NodeId) -> AssignRunnerOp | None: + for instance_id, instance in instances.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): + if node_id != worker_node_id: + continue + + if runner_id not in assigned_runners: + return AssignRunnerOp( + runner_id=runner_id, + instance_id=instance_id, + shard_metadata=instance.shard_assignments.runner_to_shard[runner_id], + hosts=instance.hosts + ) + return None + +def spin_up_runners(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], state_runners: Mapping[RunnerId, RunnerStatus], worker_node_id: NodeId) -> RunnerUpOp | None: + for _instance_id, instance in instances.items(): + if worker_node_id in instance.shard_assignments.node_to_runner and \ + assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].runner is None and \ + instance.instance_type == InstanceStatus.ACTIVE: + + # We are part of this instance, we want it up but it hasn't been spun up yet. + # Need to assert all other runners are ready before we can spin up. + ready_to_spin = True + for runner_id in instance.shard_assignments.node_to_runner.values(): + if runner_id in state_runners and state_runners[runner_id].runner_status != RunnerStatusType.Inactive: + ready_to_spin = False + + if ready_to_spin: + return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id]) + return None + +def execute_task_op(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], state_runners: Mapping[RunnerId, RunnerStatus], tasks: Mapping[TaskId, Task], worker_node_id: NodeId) -> ExecuteTaskOp | None: + for instance_id, instance in instances.items(): + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): + if node_id != worker_node_id: + continue + assert runner_id in assigned_runners + runner = assigned_runners[runner_id] + if runner.status.runner_status != RunnerStatusType.Loaded: + continue # The only previous state to get to Running is from Loaded + + for _, task in tasks.items(): + if task.instance_id == instance_id and ( + task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED + ): + if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1): + return ExecuteTaskOp(runner_id=runner_id, task=task) + else: + # We already know our own status is Loaded. We are rank 0, + # so let's check that all the other runners are running - ready for us to fire the prompt. + running_runner_count = 0 + for other_runner_id, other_runner_status in state_runners.items(): + if other_runner_id in instance.shard_assignments.node_to_runner.values() and \ + isinstance(other_runner_status, RunningRunnerStatus): + running_runner_count += 1 + + if running_runner_count == runner.shard_metadata.world_size - 1: + return ExecuteTaskOp(runner_id=runner_id, task=task) + + return None + + + +def plan(assigned_runners: dict[RunnerId, AssignedRunner], + worker_node_id: NodeId, + instances: Mapping[InstanceId, Instance], + state_runners: Mapping[RunnerId, RunnerStatus], # all global + tasks: Mapping[TaskId, Task]) -> RunnerOp | None: + # First, unassign assigned runners that are no longer in the state. 
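+    # A quick sketch of the precedence encoded below - the first rule that
+    # matches wins, so cleanup always beats creation:
+    #   1. unassign_runners  -> UnassignRunnerOp  (stale / no longer wanted)
+    #   2. failed_runners    -> RunnerFailedOp    (unhealthy, not yet marked failed)
+    #   3. spin_down_runners -> RunnerDownOp      (should stop)
+    #   4. assign_runners    -> AssignRunnerOp    (wanted, not yet assigned)
+    #   5. spin_up_runners   -> RunnerUpOp        (assigned, should be loaded)
+    #   6. execute_task_op   -> ExecuteTaskOp     (pending/failed task on a loaded runner)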
+ if unop := unassign_runners(instances, state_runners, assigned_runners): + return unop + + # mark failed runners that are not marked yet as failed + if failed_op := failed_runners(assigned_runners): + return failed_op + + # spin down runners that are no longer needed + if down_op := spin_down_runners(instances, assigned_runners, state_runners, worker_node_id): + return down_op + + # Then assign runners we do want + if assign_op := assign_runners(instances, assigned_runners, worker_node_id): + return assign_op + + # Then spin up 'ready' runners that should be active + if runner_up_op := spin_up_runners(instances, assigned_runners, state_runners, worker_node_id): + return runner_up_op + + # Then make sure things are running based on tasks. + if exec_op := execute_task_op(instances, assigned_runners, state_runners, tasks, worker_node_id): + return exec_op + + return None diff --git a/worker/runner/communication.py b/worker/runner/communication.py index 85efa090..58104724 100644 --- a/worker/runner/communication.py +++ b/worker/runner/communication.py @@ -62,7 +62,7 @@ async def supervisor_read_response( assert proc.stdout is not None, ( "proc.stdout should not be None when created with stdout=PIPE" ) - line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=10) + line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180) line: str = line_bytes.decode("utf-8").strip() if not line: diff --git a/worker/tests/conftest.py b/worker/tests/conftest.py index 2548fd05..7e31606f 100644 --- a/worker/tests/conftest.py +++ b/worker/tests/conftest.py @@ -1,36 +1,46 @@ -import asyncio from ipaddress import IPv4Address from logging import Logger, getLogger -from pathlib import Path -from typing import Awaitable, Callable +from typing import Callable, Optional import pytest -from shared.db.sqlite.connector import AsyncSQLiteEventStorage -from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager from shared.models.model_meta import get_model_meta from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from shared.types.common import CommandId, Host, NodeId +from shared.types.common import Host, NodeId from shared.types.models import ModelId, ModelMetadata -from shared.types.state import State from shared.types.tasks import ( ChatCompletionTask, TaskId, TaskStatus, TaskType, ) -from shared.types.worker.common import InstanceId, NodeStatus +from shared.types.worker.common import InstanceId from shared.types.worker.instances import Instance, InstanceStatus -from shared.types.worker.ops import ( - AssignRunnerOp, - RunnerUpOp, -) from shared.types.worker.runners import RunnerId, ShardAssignments from shared.types.worker.shards import PipelineShardMetadata -from worker.download.shard_downloader import NoopShardDownloader -from worker.main import Worker +from worker.tests.constants import ( + COMMAND_1_ID, + INSTANCE_1_ID, + MODEL_A_ID, + NODE_A, + RUNNER_1_ID, + TASK_1_ID, +) +@pytest.fixture +def user_message(): + """Override this fixture in tests to customize the message""" + return "Hello, how are you?" 
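+# A usage sketch (the test shown here is hypothetical): a nearer conftest or
+# test module can redefine `user_message` to drive a different prompt through
+# `chat_completion_task`, which builds its params from this fixture:
+#
+#   @pytest.fixture
+#   def user_message():
+#       return "What is the capital of Japan?"
+#
+#   async def test_prompt(chat_completion_task):
+#       task = chat_completion_task()
+#       assert task.task_params.messages[0].content == "What is the capital of Japan?"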
+ +@pytest.fixture +def logger() -> Logger: + return getLogger("test_logger") + +@pytest.fixture +async def model_meta() -> ModelMetadata: + return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit') + @pytest.fixture def hosts(): def _hosts(count: int, offset: int = 0) -> list[Host]: @@ -44,29 +54,8 @@ def hosts(): return _hosts - @pytest.fixture -def hosts_one(hosts: Callable[[int], list[Host]]): - return hosts(1) - - -@pytest.fixture -def hosts_two(hosts: Callable[[int], list[Host]]): - return hosts(2) - - -@pytest.fixture -def user_message(): - """Override this fixture in tests to customize the message""" - return "Hello, how are you?" - -@pytest.fixture -async def model_meta() -> ModelMetadata: - return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit') - - -@pytest.fixture -def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]: +def pipeline_shard_meta(model_meta: ModelMetadata) -> Callable[[int, int], PipelineShardMetadata]: def _pipeline_shard_meta( num_nodes: int = 1, device_rank: int = 0 ) -> PipelineShardMetadata: @@ -90,6 +79,37 @@ def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[ return _pipeline_shard_meta +@pytest.fixture +def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]]): + from typing import Optional + + def _instance( + instance_id: Optional[InstanceId] = None, + node_id: Optional[NodeId] = None, + runner_id: Optional[RunnerId] = None, + model_id: Optional[ModelId] = None, + ) -> Instance: + resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID + resolved_node_id = node_id if node_id is not None else NODE_A + resolved_runner_id = runner_id if runner_id is not None else RUNNER_1_ID + resolved_model_id = model_id if model_id is not None else MODEL_A_ID + + shard_assignments = ShardAssignments( + model_id=resolved_model_id, + runner_to_shard={ + resolved_runner_id: pipeline_shard_meta(1, 0) + }, + node_to_runner={resolved_node_id: resolved_runner_id} + ) + + return Instance( + instance_id=resolved_instance_id, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(1) + ) + return _instance + @pytest.fixture def completion_create_params(user_message: str) -> ChatCompletionTaskParams: """Creates ChatCompletionParams with the given message""" @@ -101,10 +121,14 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams: @pytest.fixture def chat_completion_task(completion_create_params: ChatCompletionTaskParams): - def _chat_completion_task(instance_id: InstanceId, task_id: TaskId) -> ChatCompletionTask: + def _chat_completion_task(instance_id: Optional[InstanceId] = None, task_id: Optional[TaskId] = None) -> ChatCompletionTask: + if instance_id is None: + instance_id = INSTANCE_1_ID + if task_id is None: + task_id = TASK_1_ID return ChatCompletionTask( task_id=task_id, - command_id=CommandId(), + command_id=COMMAND_1_ID, instance_id=instance_id, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, @@ -112,105 +136,4 @@ def chat_completion_task(completion_create_params: ChatCompletionTaskParams): ) return _chat_completion_task -@pytest.fixture -def node_id() -> NodeId: - """Shared node ID for tests""" - return NodeId() -@pytest.fixture -def state(node_id: NodeId): - node_status={ - node_id: NodeStatus.Idle - } - - return State( - node_status=node_status, - ) - -@pytest.fixture -def logger() -> Logger: 
- return getLogger("test_logger") - -@pytest.fixture -def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]): - def _instance(instance_id: InstanceId, node_id: NodeId, runner_id: RunnerId) -> Instance: - model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - runner_id: pipeline_shard_meta(1, 0) - }, - node_to_runner={node_id: runner_id} - ) - - return Instance( - instance_id=instance_id, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts_one - ) - return _instance - -@pytest.fixture -async def worker(node_id: NodeId, logger: Logger): - event_log_manager = EventLogManager(EventLogConfig(), logger) - shard_downloader = NoopShardDownloader() - await event_log_manager.initialize() - - return Worker(node_id, logger, shard_downloader, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events) - -@pytest.fixture -async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]): - """Fixture that provides a worker with an already assigned runner.""" - - instance_obj: Instance = instance(InstanceId(), worker.node_id, RunnerId()) - - # Extract runner_id from shard assignments - runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard)) - - # Assign the runner - assign_op = AssignRunnerOp( - runner_id=runner_id, - shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.hosts, - instance_id=instance_obj.instance_id, - ) - - async for _ in worker._execute_op(assign_op): # type: ignore[misc] - pass - - return worker, runner_id, instance_obj - -@pytest.fixture -async def worker_with_running_runner(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance]): - """Fixture that provides a worker with an already assigned runner.""" - worker, runner_id, instance_obj = worker_with_assigned_runner - - runner_up_op = RunnerUpOp(runner_id=runner_id) - async for _ in worker._execute_op(runner_up_op): # type: ignore[misc] - pass - - # Is the runner actually running? 
- supervisor = next(iter(worker.assigned_runners.values())).runner - assert supervisor is not None - assert supervisor.healthy - - return worker, runner_id, instance_obj - -@pytest.fixture -def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]: - async def _worker_running(node_id: NodeId) -> tuple[Worker, AsyncSQLiteEventStorage]: - event_log_manager = EventLogManager(EventLogConfig(), logger) - await event_log_manager.initialize() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - shard_downloader = NoopShardDownloader() - worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker.run()) - - return worker, global_events - - return _worker_running \ No newline at end of file diff --git a/worker/tests/constants.py b/worker/tests/constants.py new file mode 100644 index 00000000..8e139a13 --- /dev/null +++ b/worker/tests/constants.py @@ -0,0 +1,26 @@ +from typing import Final + +from shared.types.common import CommandId, NodeId +from shared.types.models import ModelId +from shared.types.tasks import TaskId +from shared.types.worker.common import InstanceId, RunnerId + +MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") + +NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") + +RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111") +RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333") + +INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222") +INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444") + +MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' +MODEL_B_ID: Final[ModelId] = 'mlx-community/TinyLlama-1.1B-Chat-v1.0' + +TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555") +TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666") + +COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777") +COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888") \ No newline at end of file diff --git a/worker/tests/test_download.py b/worker/tests/test_download.py index a201f528..c44d6e65 100644 --- a/worker/tests/test_download.py +++ b/worker/tests/test_download.py @@ -8,6 +8,7 @@ from worker.download.impl_shard_downloader import exo_shard_downloader from worker.download.shard_downloader import ShardDownloader +@pytest.mark.slow @pytest.mark.asyncio async def test_shard_downloader(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata]): shard_downloader: ShardDownloader = exo_shard_downloader() diff --git a/worker/tests/test_handlers/conftest.py b/worker/tests/test_handlers/conftest.py new file mode 100644 index 00000000..9f7801c6 --- /dev/null +++ b/worker/tests/test_handlers/conftest.py @@ -0,0 +1,70 @@ +from logging import Logger +from typing import Callable + +import pytest + +from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager +from shared.types.common import NodeId +from shared.types.worker.common import InstanceId +from shared.types.worker.instances import Instance +from shared.types.worker.ops import ( + AssignRunnerOp, + RunnerUpOp, +) +from shared.types.worker.runners import RunnerId +from worker.download.shard_downloader import 
NoopShardDownloader +from worker.tests.constants import INSTANCE_1_ID, NODE_A, RUNNER_1_ID +from worker.worker import Worker + + +@pytest.fixture +def user_message(): + return "What, according to Douglas Adams, is the meaning of life, the universe and everything?" + + +@pytest.fixture +async def worker(logger: Logger): + event_log_manager = EventLogManager(EventLogConfig(), logger) + shard_downloader = NoopShardDownloader() + await event_log_manager.initialize() + + return Worker(NODE_A, logger, shard_downloader, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events) + +# TODO: instance_id and runner_id are selectable. +@pytest.fixture +async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]): + """Fixture that provides a worker with an already assigned runner.""" + + instance_id = INSTANCE_1_ID + runner_id = RUNNER_1_ID + instance_obj: Instance = instance(instance_id, worker.node_id, runner_id) + + # Assign the runner + assign_op = AssignRunnerOp( + runner_id=runner_id, + shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], + hosts=instance_obj.hosts, + instance_id=instance_obj.instance_id, + ) + + async for _ in worker.execute_op(assign_op): + pass + + return worker, instance_obj + +@pytest.fixture +async def worker_with_running_runner(worker_with_assigned_runner: tuple[Worker, Instance]): + """Fixture that provides a worker with an already assigned runner.""" + worker, instance_obj = worker_with_assigned_runner + + runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) + async for _ in worker.execute_op(runner_up_op): + pass + + # Is the runner actually running? + supervisor = next(iter(worker.assigned_runners.values())).runner + assert supervisor is not None + assert supervisor.healthy + + return worker, instance_obj + diff --git a/worker/tests/test_handlers/test_handlers_happy.py b/worker/tests/test_handlers/test_handlers_happy.py new file mode 100644 index 00000000..5d2dc0b8 --- /dev/null +++ b/worker/tests/test_handlers/test_handlers_happy.py @@ -0,0 +1,159 @@ +from typing import Callable + +import pytest + +from shared.types.common import NodeId +from shared.types.events import ( + ChunkGenerated, + RunnerDeleted, + RunnerStatusUpdated, + TaskStateUpdated, +) +from shared.types.events.chunks import TokenChunk +from shared.types.tasks import ChatCompletionTask, TaskStatus +from shared.types.worker.common import RunnerId +from shared.types.worker.instances import Instance, InstanceId +from shared.types.worker.ops import ( + AssignRunnerOp, + ExecuteTaskOp, + RunnerDownOp, + RunnerUpOp, + UnassignRunnerOp, +) +from shared.types.worker.runners import ( + DownloadingRunnerStatus, + InactiveRunnerStatus, + LoadedRunnerStatus, + RunningRunnerStatus, +) +from worker.main import Worker +from worker.tests.constants import ( + RUNNER_1_ID, +) +from worker.tests.test_handlers.utils import read_events_op + + +@pytest.mark.asyncio +async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]): + instance_obj: Instance = instance(InstanceId(), worker.node_id, RUNNER_1_ID) + + assign_op = AssignRunnerOp( + runner_id=RUNNER_1_ID, + shard_metadata=instance_obj.shard_assignments.runner_to_shard[RUNNER_1_ID], + hosts=instance_obj.hosts, + instance_id=instance_obj.instance_id, + ) + + events = await read_events_op(worker, assign_op) + + # We should have a status update saying 'starting'. 
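+    # (Under the reworked status model this means: first a RunnerStatusUpdated
+    # carrying DownloadingRunnerStatus, then one carrying InactiveRunnerStatus -
+    # assuming the NoopShardDownloader used by the `worker` fixture completes
+    # the download immediately.)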
+    assert len(events) == 2
+    assert isinstance(events[0], RunnerStatusUpdated)
+    assert isinstance(events[0].runner_status, DownloadingRunnerStatus)
+    assert isinstance(events[1], RunnerStatusUpdated)
+    assert isinstance(events[1].runner_status, InactiveRunnerStatus)
+
+    # And the runner should be assigned
+    assert RUNNER_1_ID in worker.assigned_runners
+    assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
+
+@pytest.mark.asyncio
+async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, Instance]):
+    worker, _ = worker_with_assigned_runner
+
+    unassign_op = UnassignRunnerOp(
+        runner_id=RUNNER_1_ID
+    )
+
+    events = await read_events_op(worker, unassign_op)
+
+    # We should have no assigned runners; a single RunnerDeleted event is emitted
+    assert len(worker.assigned_runners) == 0
+    assert len(events) == 1
+    assert isinstance(events[0], RunnerDeleted)
+
+@pytest.mark.asyncio
+async def test_runner_up_op(
+    worker_with_assigned_runner: tuple[Worker, Instance],
+    chat_completion_task: Callable[[], ChatCompletionTask],
+    ):
+    worker, _ = worker_with_assigned_runner
+
+    runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
+
+    events = await read_events_op(worker, runner_up_op)
+
+    assert len(events) == 1
+    assert isinstance(events[0], RunnerStatusUpdated)
+    assert isinstance(events[0].runner_status, LoadedRunnerStatus)
+
+    # Is the runner actually running?
+    supervisor = next(iter(worker.assigned_runners.values())).runner
+    assert supervisor is not None
+    assert supervisor.healthy
+
+    full_response = ''
+
+    async for chunk in supervisor.stream_response(task=chat_completion_task()):
+        if isinstance(chunk, TokenChunk):
+            full_response += chunk.text
+
+    assert "42" in full_response.lower(), (
+        f"Expected '42' in response, but got: {full_response}"
+    )
+
+    runner = worker.assigned_runners[RUNNER_1_ID].runner
+    assert runner is not None
+    await runner.astop() # Neat cleanup.
+
+@pytest.mark.asyncio
+async def test_runner_down_op(worker_with_running_runner: tuple[Worker, Instance]):
+    worker, _ = worker_with_running_runner
+
+    runner_down_op = RunnerDownOp(runner_id=RUNNER_1_ID)
+    events = await read_events_op(worker, runner_down_op)
+
+    assert len(events) == 1
+    assert isinstance(events[0], RunnerStatusUpdated)
+    assert isinstance(events[0].runner_status, InactiveRunnerStatus)
+
+@pytest.mark.asyncio
+async def test_execute_task_op(
+    worker_with_running_runner: tuple[Worker, Instance],
+    chat_completion_task: Callable[[], ChatCompletionTask]):
+    worker, _ = worker_with_running_runner
+
+    execute_task_op = ExecuteTaskOp(
+        runner_id=RUNNER_1_ID,
+        task=chat_completion_task()
+    )
+
+    events = await read_events_op(worker, execute_task_op)
+
+    assert len(events) > 20
+
+    print(f'{events=}')
+
+
+    assert isinstance(events[0], RunnerStatusUpdated)
+    assert isinstance(events[0].runner_status, RunningRunnerStatus)
+
+    assert isinstance(events[1], TaskStateUpdated)
+    assert events[1].task_status == TaskStatus.RUNNING # The task started.
+
+    assert isinstance(events[-2], TaskStateUpdated)
+    assert events[-2].task_status == TaskStatus.COMPLETE # The task ran to completion.
+
+    assert isinstance(events[-1], RunnerStatusUpdated)
+    assert isinstance(events[-1].runner_status, LoadedRunnerStatus) # It should not have failed.
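+    # Arithmetic behind the `- 4` below: the stream is bracketed by two
+    # RunnerStatusUpdated events (Running, then Loaded) and two TaskStateUpdated
+    # events (RUNNING, then COMPLETE); every event in between should be a
+    # ChunkGenerated carrying a TokenChunk.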
+ + gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)] + text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)] + assert len(text_chunks) == len(events) - 4 + + output_text = ''.join([x.text for x in text_chunks]) + assert '42' in output_text + + runner = worker.assigned_runners[RUNNER_1_ID].runner + assert runner is not None + await runner.astop() # Neat cleanup. diff --git a/worker/tests/test_handlers/test_handlers_sad.py b/worker/tests/test_handlers/test_handlers_sad.py new file mode 100644 index 00000000..05238c8e --- /dev/null +++ b/worker/tests/test_handlers/test_handlers_sad.py @@ -0,0 +1,61 @@ +## Tests for worker state handlers + +from typing import Callable + +import pytest + +from shared.types.events import ( + RunnerStatusUpdated, + TaskFailed, + TaskStateUpdated, +) +from shared.types.tasks import ChatCompletionTask, TaskStatus +from shared.types.worker.instances import Instance +from shared.types.worker.ops import ( + ExecuteTaskOp, +) +from shared.types.worker.runners import ( + FailedRunnerStatus, + RunningRunnerStatus, +) +from worker.main import Worker +from worker.tests.constants import RUNNER_1_ID +from worker.tests.test_handlers.utils import read_events_op + + +@pytest.mark.asyncio +async def test_execute_task_fails( + worker_with_running_runner: tuple[Worker, Instance], + chat_completion_task: Callable[[], ChatCompletionTask]): + worker, _ = worker_with_running_runner + + task = chat_completion_task() + messages = task.task_params.messages + messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL' + + execute_task_op = ExecuteTaskOp( + runner_id=RUNNER_1_ID, + task=task + ) + + events = await read_events_op(worker, execute_task_op) + + assert len(events) == 5 + + print(events) + + assert isinstance(events[0], RunnerStatusUpdated) + assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start. + + assert isinstance(events[1], TaskStateUpdated) + assert events[1].task_status == TaskStatus.RUNNING # It tried to start. + + assert isinstance(events[2], TaskStateUpdated) + assert events[2].task_status == TaskStatus.FAILED # Task marked as failed. + + assert isinstance(events[3], TaskFailed) + + assert isinstance(events[4], RunnerStatusUpdated) + assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed. + +# TODO: Much more to do here! 
\ No newline at end of file diff --git a/worker/tests/test_handlers/utils.py b/worker/tests/test_handlers/utils.py new file mode 100644 index 00000000..8e97949b --- /dev/null +++ b/worker/tests/test_handlers/utils.py @@ -0,0 +1,18 @@ +## Tests for worker state handlers + + + +from shared.types.events import ( + Event, +) +from shared.types.worker.ops import ( + RunnerOp, +) +from worker.main import Worker + + +async def read_events_op(worker: Worker, op: RunnerOp) -> list[Event]: + events: list[Event] = [] + async for event in worker.execute_op(op): + events.append(event) + return events \ No newline at end of file diff --git a/worker/tests/test_integration/conftest.py b/worker/tests/test_integration/conftest.py new file mode 100644 index 00000000..8e3faa39 --- /dev/null +++ b/worker/tests/test_integration/conftest.py @@ -0,0 +1,36 @@ +import asyncio +from logging import Logger +from typing import Awaitable, Callable + +import pytest + +from shared.db.sqlite.connector import AsyncSQLiteEventStorage +from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager +from shared.types.common import NodeId +from worker.download.shard_downloader import NoopShardDownloader +from worker.main import run +from worker.worker import Worker + + +@pytest.fixture +def user_message(): + """Override this fixture in tests to customize the message""" + return "What is the capital of Japan?" + + +@pytest.fixture +def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]: + async def _worker_running(node_id: NodeId) -> tuple[Worker, AsyncSQLiteEventStorage]: + event_log_manager = EventLogManager(EventLogConfig(), logger) + await event_log_manager.initialize() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + shard_downloader = NoopShardDownloader() + worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(run(worker)) + + return worker, global_events + + return _worker_running \ No newline at end of file diff --git a/worker/tests/test_worker_integration_utils.py b/worker/tests/test_integration/integration_utils.py similarity index 100% rename from worker/tests/test_worker_integration_utils.py rename to worker/tests/test_integration/integration_utils.py diff --git a/worker/tests/test_worker_integration.py b/worker/tests/test_integration/test_creation.py similarity index 53% rename from worker/tests/test_worker_integration.py rename to worker/tests/test_integration/test_creation.py index 99f8ed05..4e13a18b 100644 --- a/worker/tests/test_worker_integration.py +++ b/worker/tests/test_integration/test_creation.py @@ -1,14 +1,11 @@ import asyncio from logging import Logger -from typing import Awaitable, Callable, Final - -import pytest +from typing import Awaitable, Callable # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from shared.types.common import CommandId, Host, NodeId +from shared.types.common import Host, NodeId from shared.types.events import ( InstanceCreated, InstanceDeleted, @@ -18,7 +15,7 @@ from shared.types.events import ( ) from shared.types.events.chunks import TokenChunk from shared.types.models import ModelId -from shared.types.tasks 
import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
+from shared.types.tasks import Task, TaskId
 from shared.types.worker.common import InstanceId, RunnerId
 from shared.types.worker.instances import (
     Instance,
@@ -26,35 +23,31 @@ from shared.types.worker.instances import (
     ShardAssignments,
 )
 from shared.types.worker.runners import (
-    AssignedRunnerStatus,
     DownloadingRunnerStatus,
     # RunningRunnerStatus,
     FailedRunnerStatus,
+    InactiveRunnerStatus,
     LoadedRunnerStatus,
-    ReadyRunnerStatus,
 )
 from shared.types.worker.shards import PipelineShardMetadata
+from worker.common import AssignedRunner
 from worker.download.shard_downloader import NoopShardDownloader
-from worker.main import AssignedRunner, Worker
-from worker.tests.test_worker_integration_utils import read_streaming_response
+from worker.main import run
+from worker.tests.constants import (
+    INSTANCE_1_ID,
+    MASTER_NODE_ID,
+    NODE_A,
+    NODE_B,
+    RUNNER_1_ID,
+    RUNNER_2_ID,
+    TASK_1_ID,
+    TASK_2_ID,
+)
+from worker.tests.test_integration.integration_utils import (
+    read_streaming_response,
+)
+from worker.worker import Worker
 
-MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
-NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
-NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
-
-# Define constant IDs for deterministic test cases
-RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
-INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
-RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
-INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
-MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
-MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
-TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
-TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
-
-@pytest.fixture
-def user_message():
-    return "What is the capital of Japan?"
 
 async def test_runner_assigned(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
@@ -63,8 +56,6 @@ async def test_runner_assigned(
 
     worker, global_events = await worker_running(NODE_A)
 
-    print(worker)
-
     instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
     instance_value.instance_type = InstanceStatus.INACTIVE
 
@@ -82,22 +73,19 @@ async def test_runner_assigned(
     # Ensure the worker has taken the correct action
     assert len(worker.assigned_runners) == 1
     assert RUNNER_1_ID in worker.assigned_runners
-    assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, ReadyRunnerStatus)
+    assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
 
     # Ensure the correct events have been emitted
     events = await global_events.get_events_since(0)
-    print(events)
-    assert len(events) >= 4 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
+    assert len(events) >= 3 # len(events) is 3 if it's already downloaded. It is > 3 if there have to be download events.
 assert isinstance(events[1].event, RunnerStatusUpdated)
-    assert isinstance(events[1].event.runner_status, AssignedRunnerStatus)
-    assert isinstance(events[2].event, RunnerStatusUpdated)
-    assert isinstance(events[2].event.runner_status, DownloadingRunnerStatus)
+    assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
     assert isinstance(events[-1].event, RunnerStatusUpdated)
-    assert isinstance(events[-1].event.runner_status, ReadyRunnerStatus)
+    assert isinstance(events[-1].event.runner_status, InactiveRunnerStatus)
 
     # Ensure state is correct
-    assert isinstance(worker.state.runners[RUNNER_1_ID], ReadyRunnerStatus)
+    assert isinstance(worker.state.runners[RUNNER_1_ID], InactiveRunnerStatus)
 
 async def test_runner_assigned_active(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
@@ -118,7 +106,7 @@
         origin=MASTER_NODE_ID
     )
 
-    await asyncio.sleep(1.0)
+    await asyncio.sleep(2.0)
 
     assert len(worker.assigned_runners) == 1
     assert RUNNER_1_ID in worker.assigned_runners
 
     # Ensure the correct events have been emitted
     events = await global_events.get_events_since(0)
-    assert len(events) >= 5 # len(events) is 5 if it's already downloaded. It is > 5 if there have to be download events.
+    assert len(events) >= 4 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
     assert isinstance(events[1].event, RunnerStatusUpdated)
-    assert isinstance(events[1].event.runner_status, AssignedRunnerStatus)
-    assert isinstance(events[2].event, RunnerStatusUpdated)
-    assert isinstance(events[2].event.runner_status, DownloadingRunnerStatus)
+    assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
     assert isinstance(events[-2].event, RunnerStatusUpdated)
-    assert isinstance(events[-2].event.runner_status, ReadyRunnerStatus)
+    assert isinstance(events[-2].event.runner_status, InactiveRunnerStatus)
     assert isinstance(events[-1].event, RunnerStatusUpdated)
     assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
 
@@ -201,7 +187,7 @@ async def test_runner_unassigns(
         origin=MASTER_NODE_ID
     )
 
-    await asyncio.sleep(0.5)
+    await asyncio.sleep(2.0)
 
     # already tested by test_runner_assigned_active
     assert len(worker.assigned_runners) == 1
@@ -210,12 +196,11 @@ async def test_runner_unassigns(
 
     # Ensure the correct events have been emitted (creation)
     events = await global_events.get_events_since(0)
-    assert len(events) >= 5
+    assert len(events) >= 4
     assert isinstance(events[-1].event, RunnerStatusUpdated)
     assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
 
     # Ensure state is correct
-    print(worker.state)
     assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
 
     await global_events.append_events(
@@ -227,7 +212,6 @@ async def test_runner_unassigns(
 
     await asyncio.sleep(0.3)
 
-    print(worker.state)
     assert len(worker.assigned_runners) == 0
 
     # Ensure the correct events have been emitted (deletion)
@@ -236,221 +220,6 @@ async def test_runner_unassigns(
 
     # After deletion, runner should be removed from state.runners
     assert len(worker.state.runners) == 0
 
-async def test_runner_inference(
-    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
-    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
-    chat_completion_task: Callable[[InstanceId, TaskId], Task]
-    ):
-    _worker, global_events = await worker_running(NODE_A)
-
-    instance_value: Instance =
instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated( - instance=instance_value, - ), - TaskCreated( - task_id=task.task_id, - task=task - ) - ], - origin=MASTER_NODE_ID - ) - - seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert 'tokyo' in response_string.lower() - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID - ) - - await asyncio.sleep(0.3) - -async def test_2_runner_inference( - logger: Logger, - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task] - ): - event_log_manager = EventLogManager(EventLogConfig(), logger) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker1.run()) - - worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker2.run()) - - ## Instance - model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1) - }, - node_to_runner={ - NODE_A: RUNNER_1_ID, - NODE_B: RUNNER_2_ID - } - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2) - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated( - instance=instance - ), - TaskCreated( - task_id=task.task_id, - task=task - ) - ], - origin=MASTER_NODE_ID - ) - - seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert 'tokyo' in response_string.lower() - - - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID - ) - - await asyncio.sleep(2.0) - -async def test_2_runner_multi_message( - logger: Logger, - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - ): - event_log_manager = EventLogManager(EventLogConfig(), logger) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker1.run()) - - worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker2.run()) - - ## Instance - 
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1) - }, - node_to_runner={ - NODE_A: RUNNER_1_ID, - NODE_B: RUNNER_2_ID - } - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(2) - ) - - # Task - we have three messages here, which is what the task is about - - completion_create_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage(role="user", content='What is the capital of France?'), - ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'), - ChatCompletionMessage(role="user", content='Ok great. Now write me a haiku about what you can do there.'), - ], - stream=True, - ) - - task = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=CommandId(), - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=completion_create_params - ) - - await global_events.append_events( - [ - InstanceCreated( - instance=instance - ), - TaskCreated( - task_id=task.task_id, - task=task - ) - ], - origin=MASTER_NODE_ID - ) - - seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love')) - - - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID - ) - - await asyncio.sleep(2.0) async def test_runner_respawn( @@ -467,10 +236,10 @@ async def test_runner_respawn( await global_events.delete_all_events() worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker1.run()) + asyncio.create_task(run(worker1)) worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) - asyncio.create_task(worker2.run()) + asyncio.create_task(run(worker2)) ## Instance model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') @@ -534,21 +303,18 @@ async def test_runner_respawn( await asyncio.sleep(5.0) events = await global_events.get_events_since(idx) - print(f'{events=}') # assert len(events) == 2 assert isinstance(events[0].event, RunnerStatusUpdated) assert isinstance(events[0].event.runner_status, FailedRunnerStatus) assert isinstance(events[1].event, RunnerStatusUpdated) - assert isinstance(events[1].event.runner_status, ReadyRunnerStatus) + assert isinstance(events[1].event.runner_status, InactiveRunnerStatus) assert events[1].event.runner_id == RUNNER_2_ID assert isinstance(events[2].event, RunnerStatusUpdated) - assert isinstance(events[2].event.runner_status, ReadyRunnerStatus) + assert isinstance(events[2].event.runner_status, InactiveRunnerStatus) assert events[2].event.runner_id == RUNNER_1_ID - print(worker1.state) - print(worker2.state) for event in [events[3].event, events[4].event]: assert isinstance(event, RunnerStatusUpdated) diff --git a/worker/tests/test_integration/test_inference.py b/worker/tests/test_integration/test_inference.py 
new file mode 100644 index 00000000..8b291db9 --- /dev/null +++ b/worker/tests/test_integration/test_inference.py @@ -0,0 +1,256 @@ +import asyncio +from logging import Logger +from typing import Awaitable, Callable + +# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py +from shared.db.sqlite.connector import AsyncSQLiteEventStorage +from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager +from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams +from shared.types.common import CommandId, Host, NodeId +from shared.types.events import ( + InstanceCreated, + InstanceDeleted, + TaskCreated, +) +from shared.types.models import ModelId +from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType +from shared.types.worker.common import InstanceId, RunnerId +from shared.types.worker.instances import ( + Instance, + InstanceStatus, + ShardAssignments, +) +from shared.types.worker.shards import PipelineShardMetadata +from worker.download.shard_downloader import NoopShardDownloader +from worker.main import run +from worker.tests.constants import ( + INSTANCE_1_ID, + MASTER_NODE_ID, + NODE_A, + NODE_B, + RUNNER_1_ID, + RUNNER_2_ID, + TASK_1_ID, +) +from worker.tests.test_integration.integration_utils import ( + read_streaming_response, +) +from worker.worker import Worker + + +async def test_runner_inference( + worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + chat_completion_task: Callable[[InstanceId, TaskId], Task] + ): + _worker, global_events = await worker_running(NODE_A) + + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated( + instance=instance_value, + ), + TaskCreated( + task_id=task.task_id, + task=task + ) + ], + origin=MASTER_NODE_ID + ) + + seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) + + assert seen_task_started + assert seen_task_finished + assert 'tokyo' in response_string.lower() + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(0.3) + +async def test_2_runner_inference( + logger: Logger, + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + chat_completion_task: Callable[[InstanceId, TaskId], Task] + ): + event_log_manager = EventLogManager(EventLogConfig(), logger) + await event_log_manager.initialize() + shard_downloader = NoopShardDownloader() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(run(worker1)) + + worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(run(worker2)) + + ## Instance + model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') + + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1) + }, + node_to_runner={ + NODE_A: 
RUNNER_1_ID, + NODE_B: RUNNER_2_ID + } + ) + + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2) + ) + + task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated( + instance=instance + ), + TaskCreated( + task_id=task.task_id, + task=task + ) + ], + origin=MASTER_NODE_ID + ) + + seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) + + assert seen_task_started + assert seen_task_finished + assert 'tokyo' in response_string.lower() + + + idx = await global_events.get_last_idx() + await asyncio.sleep(1.0) + events = await global_events.get_events_since(idx) + assert len(events) == 0 + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(2.0) + +# TODO: Multi message parallel +async def test_2_runner_multi_message( + logger: Logger, + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + ): + event_log_manager = EventLogManager(EventLogConfig(), logger) + await event_log_manager.initialize() + shard_downloader = NoopShardDownloader() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(run(worker1)) + + worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(run(worker2)) + + ## Instance + model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') + + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1) + }, + node_to_runner={ + NODE_A: RUNNER_1_ID, + NODE_B: RUNNER_2_ID + } + ) + + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2) + ) + + # Task - we have three messages here, which is what the task is about + + completion_create_params = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage(role="user", content='What is the capital of France?'), + ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'), + ChatCompletionMessage(role="user", content='Ok great. 
Now write me a haiku about what you can do there.'), + ], + stream=True, + ) + + task = ChatCompletionTask( + task_id=TASK_1_ID, + command_id=CommandId(), + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=completion_create_params + ) + + await global_events.append_events( + [ + InstanceCreated( + instance=instance + ), + TaskCreated( + task_id=task.task_id, + task=task + ) + ], + origin=MASTER_NODE_ID + ) + + seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) + + assert seen_task_started + assert seen_task_finished + assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love')) + + + idx = await global_events.get_last_idx() + await asyncio.sleep(1.0) + events = await global_events.get_events_since(idx) + assert len(events) == 0 + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(2.0) diff --git a/worker/tests/test_supervisor_errors.py b/worker/tests/test_integration/test_supervisor_errors.py similarity index 65% rename from worker/tests/test_supervisor_errors.py rename to worker/tests/test_integration/test_supervisor_errors.py index 87390898..4dd62dba 100644 --- a/worker/tests/test_supervisor_errors.py +++ b/worker/tests/test_integration/test_supervisor_errors.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import AsyncGenerator from types import CoroutineType -from typing import Any, Awaitable, Callable, Final +from typing import Any, Awaitable, Callable import pytest from _pytest.monkeypatch import MonkeyPatch @@ -15,11 +15,9 @@ from shared.types.events import ( InstanceDeleted, RunnerStatusUpdated, TaskCreated, - TaskFailed, TaskStateUpdated, ) from shared.types.events.chunks import GenerationChunk, TokenChunk -from shared.types.models import ModelId from shared.types.tasks import Task, TaskId, TaskStatus from shared.types.worker.common import InstanceId, RunnerId from shared.types.worker.instances import ( @@ -29,20 +27,14 @@ from shared.types.worker.instances import ( from shared.types.worker.runners import FailedRunnerStatus from worker.main import Worker from worker.runner.runner_supervisor import RunnerSupervisor +from worker.tests.constants import ( + INSTANCE_1_ID, + MASTER_NODE_ID, + NODE_A, + RUNNER_1_ID, + TASK_1_ID, +) -MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") -NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") -NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") - -# Define constant IDs for deterministic test cases -RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111") -INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222") -RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333") -INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444") -MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' -MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' -TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555") -TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666") @pytest.fixture def user_message(): @@ -187,65 +179,65 @@ async def test_stream_response_failed_once( await asyncio.sleep(0.3) -async def test_stream_response_timeout( - monkeypatch: MonkeyPatch, - worker_running: 
Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task] -): - async def mock_stream_response( - self: RunnerSupervisor, - task: Task, - request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, - ) -> AsyncGenerator[GenerationChunk]: - # TODO: Also a test where we yield a few chunks and then time out. - print('sleeping starting') - await asyncio.sleep(4.) - print('sleeping finished') - return - yield +# async def test_stream_response_timeout( +# monkeypatch: MonkeyPatch, +# worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], +# instance: Callable[[InstanceId, NodeId, RunnerId], Instance], +# chat_completion_task: Callable[[InstanceId, TaskId], Task] +# ): +# async def mock_stream_response( +# self: RunnerSupervisor, +# task: Task, +# request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, +# ) -> AsyncGenerator[GenerationChunk]: +# # TODO: Also a test where we yield a few chunks and then time out. +# print('sleeping starting') +# await asyncio.sleep(4.) +# print('sleeping finished') +# return +# yield - monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response) +# monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response) - worker, global_events = await worker_running(NODE_A) +# worker, global_events = await worker_running(NODE_A) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE +# instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) +# instance_value.instance_type = InstanceStatus.ACTIVE - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task) - ], - origin=MASTER_NODE_ID - ) +# task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) +# await global_events.append_events( +# [ +# InstanceCreated(instance=instance_value), +# TaskCreated(task_id=task.task_id, task=task) +# ], +# origin=MASTER_NODE_ID +# ) - await asyncio.sleep(7.) +# await asyncio.sleep(7.) - # as we reset the failures back to zero when we have a successful inference. +# # as we reset the failures back to zero when we have a successful inference. 
- # print('ASSERTION ERR:') - # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1]) +# # print('ASSERTION ERR:') +# # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1]) - assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 - assert worker.state.tasks[TASK_1_ID].error_type is None - assert worker.state.tasks[TASK_1_ID].error_message is None +# assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 +# assert worker.state.tasks[TASK_1_ID].error_type is None +# assert worker.state.tasks[TASK_1_ID].error_message is None - events = await global_events.get_events_since(0) - print(events) - assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1 - assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1 - assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1 +# events = await global_events.get_events_since(0) +# print(events) +# assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1 +# assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1 +# assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1 - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID - ) +# await global_events.append_events( +# [ +# InstanceDeleted( +# instance_id=instance_value.instance_id, +# ), +# ], +# origin=MASTER_NODE_ID +# ) - await asyncio.sleep(0.3) \ No newline at end of file +# await asyncio.sleep(0.3) \ No newline at end of file diff --git a/worker/tests/test_plan/test_worker_plan.py b/worker/tests/test_plan/test_worker_plan.py new file mode 100644 index 00000000..a14521cb --- /dev/null +++ b/worker/tests/test_plan/test_worker_plan.py @@ -0,0 +1,540 @@ +from __future__ import annotations + +import logging + +import pytest + +from shared.types.api import ChatCompletionMessage +from shared.types.state import State +from shared.types.tasks import ( + ChatCompletionTask, + ChatCompletionTaskParams, + TaskStatus, + TaskType, +) +from shared.types.worker.common import NodeStatus +from shared.types.worker.downloads import ( + DownloadPending, +) +from shared.types.worker.instances import InstanceStatus +from shared.types.worker.ops import ( + AssignRunnerOp, + ExecuteTaskOp, + RunnerDownOp, + RunnerUpOp, + UnassignRunnerOp, +) +from shared.types.worker.runners import ( + DownloadingRunnerStatus, + FailedRunnerStatus, + InactiveRunnerStatus, + LoadedRunnerStatus, + RunningRunnerStatus, +) +from shared.types.worker.shards import PipelineShardMetadata +from worker.common import AssignedRunner +from worker.download.shard_downloader import NoopShardDownloader +from worker.main import Worker +from worker.plan import plan +from worker.tests.constants import ( + COMMAND_1_ID, + INSTANCE_1_ID, + MODEL_A_ID, + NODE_A, + NODE_B, + RUNNER_1_ID, + RUNNER_2_ID, + TASK_1_ID, +) +from worker.tests.test_plan.test_worker_plan_utils import ( + InProcessRunner, + PlanTestCase, + make_downloading_status, + make_model_meta, + make_state, + make_test_case, +) + +""" +The idea with these tests is to define declaratively the input and expected output of the worker.plan function. 
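+
+For orientation, the smallest declarative case (it mirrors the first entry in
+_get_test_cases below) has this shape:
+
+    PlanTestCase(
+        description="no runners -> no-op",
+        in_process_runners=[],
+        state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
+        expected_op=None,
+    )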
+
+We initialize a Worker with InProcessRunners, construct a State, and pass both into
+plan(). We then check which operation plan() returns.
+
+Note that the 'self' node is always NODE_A. This is why some of the failure-case
+scenarios below swap ranks and statuses between the two runners.
+"""
+
+
+def _get_test_cases() -> list[PlanTestCase]:
+    # Shared model metadata, referenced by the expected ops below.
+    model_a_meta = make_model_meta(MODEL_A_ID)
+    return [
+        PlanTestCase(
+            description="no runners -> no-op",
+            in_process_runners=[],
+            state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
+            expected_op=None,
+        ),
+
+        # Both 'assigned' and 'downloading' are blocking ops - so if we are in either of these we should unassign to retry.
+        # This needs to change when we move to an async worker.
+        make_test_case(
+            description="runner state assigned, runner is assigned and downloading -> unassign",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': make_downloading_status(NODE_A),
+                'downloaded': False
+            }],
+            instance_status=InstanceStatus.INACTIVE,
+            expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
+        ),
+
+        make_test_case(
+            description="ready runner, model present -> no-op",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': InactiveRunnerStatus(),
+                'downloaded': True
+            }],
+            instance_status=InstanceStatus.INACTIVE,
+            expected_op=None,
+        ),
+
+        PlanTestCase(
+            description="runner in state but not yet assigned -> AssignRunnerOp",
+            in_process_runners=[],
+            state=make_state(
+                runner_specs_per_instance={
+                    INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())]
+                },
+                model_id=MODEL_A_ID,
+                instance_status=InstanceStatus.ACTIVE,  # Either active or inactive should yield the same.
+            ),
+            expected_op=AssignRunnerOp(
+                instance_id=INSTANCE_1_ID,
+                runner_id=RUNNER_1_ID,
+                shard_metadata=PipelineShardMetadata(
+                    device_rank=0,
+                    world_size=1,
+                    model_meta=model_a_meta,
+                    start_layer=0,
+                    end_layer=1,
+                    n_layers=1,
+                ),
+                hosts=[]
+            ),
+        ),
+
+        PlanTestCase(
+            description="runner assigned but no longer in state -> UnassignRunnerOp",
+            in_process_runners=[
+                InProcessRunner(
+                    runner_id=RUNNER_1_ID,
+                    instance_id=INSTANCE_1_ID,
+                    model_id=MODEL_A_ID,
+                    status=InactiveRunnerStatus(),
+                    downloaded=False,
+                )
+            ],
+            state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
+            expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
+        ),
+
+        make_test_case(
+            description="ready runner (and state up) -> expect RunnerUpOp",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': InactiveRunnerStatus(),
+                'downloaded': True
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
+        ),
+
+        make_test_case(
+            description="1 ready, 1 downloading (and state up) -> no-op",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': InactiveRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A)),
+                    'downloaded': False
+                }
+            ],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=None
+        ),
+
+        make_test_case(
+            description="2 ready runners (and state up) -> expect RunnerUpOp",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': InactiveRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': InactiveRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=RunnerUpOp(runner_id=RUNNER_1_ID)
+        ),
+
+        make_test_case(
+            description="loaded runner (and state down) -> expect RunnerDownOp",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': LoadedRunnerStatus(),
+                'downloaded': True
+            }],
+            instance_status=InstanceStatus.INACTIVE,
+            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
+        ),
+
+        make_test_case(
+            description="failed runner (and state down) -> expect RunnerDownOp",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': FailedRunnerStatus(),
+                'downloaded': True
+            }],
+            instance_status=InstanceStatus.INACTIVE,
+            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
+        ),
+
+        make_test_case(
+            description="loaded runner, model present, task pending -> expect ExecuteTaskOp",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': LoadedRunnerStatus(),
+                'downloaded': True
+            }],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
+                task_id=TASK_1_ID,
+                command_id=COMMAND_1_ID,
+                instance_id=INSTANCE_1_ID,
+                task_type=TaskType.CHAT_COMPLETION,
+                task_status=TaskStatus.PENDING,
+                task_params=ChatCompletionTaskParams(
+                    model=str(MODEL_A_ID),
+                    messages=[ChatCompletionMessage(role="user", content="Hello, world!")]
+                ),
+            )),
+        ),
+
+        # We should only run rank 0 once all other ranks are running.
+        make_test_case(
+            description="two loaded runners & task, i'm rank 0 -> no-op",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=None
+        ),
+
+        make_test_case(
+            description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 1,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 0,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=ExecuteTaskOp(
+                runner_id=RUNNER_1_ID,
+                task=ChatCompletionTask(
+                    task_id=TASK_1_ID,
+                    command_id=COMMAND_1_ID,
+                    instance_id=INSTANCE_1_ID,
+                    task_type=TaskType.CHAT_COMPLETION,
+                    task_params=ChatCompletionTaskParams(
+                        model=str(MODEL_A_ID),
+                        messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
+                    ),
+                    task_status=TaskStatus.PENDING,
+                ),
+            ),
+        ),
+
+        make_test_case(
+            description="rank 1 running, rank 0 loaded, i'm rank 0 -> expect ExecuteTaskOp on rank 0",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': RunningRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            tasks=[{
+                'task_id': TASK_1_ID,
+                'instance_id': INSTANCE_1_ID,
+                'status': TaskStatus.PENDING,
+                'messages': [{'role': 'user', 'content': 'Hello, world!'}]
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=ExecuteTaskOp(
+                runner_id=RUNNER_1_ID,
+                task=ChatCompletionTask(
+                    task_id=TASK_1_ID,
+                    command_id=COMMAND_1_ID,
+                    instance_id=INSTANCE_1_ID,
+                    task_type=TaskType.CHAT_COMPLETION,
+                    task_params=ChatCompletionTaskParams(
+                        model=str(MODEL_A_ID),
+                        messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
+                    ),
+                    task_status=TaskStatus.PENDING,
+                ),
+            ),
+        ),
+
+        make_test_case(
+            description="this runner failed (1 node) -> RunnerDownOp",
+            runner_specs=[{
+                'runner_id': RUNNER_1_ID,
+                'node_id': NODE_A,
+                'device_rank': 0,
+                'status': FailedRunnerStatus(),
+                'downloaded': True
+            }],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
+        ),
+
+        make_test_case(
+            description="other runner failed -> RunnerDownOp",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
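+                    # RUNNER_2 (the remote rank) has failed; per the expected_op below,
+                    # plan() should spin down our healthy local runner as well.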
+                    'status': FailedRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
+        ),
+
+
+        make_test_case(
+            description="this runner failed (2 nodes) -> no-op",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': FailedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': LoadedRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=None
+        ),
+
+        make_test_case(
+            description="this node failed, other node spun down -> RunnerDownOp",
+            runner_specs=[
+                {
+                    'runner_id': RUNNER_1_ID,
+                    'node_id': NODE_A,
+                    'device_rank': 0,
+                    'status': FailedRunnerStatus(),
+                    'downloaded': True
+                },
+                {
+                    'runner_id': RUNNER_2_ID,
+                    'node_id': NODE_B,
+                    'device_rank': 1,
+                    'status': InactiveRunnerStatus(),
+                    'downloaded': True
+                }
+            ],
+            instance_status=InstanceStatus.ACTIVE,
+            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
+        ),
+
+    ]
+
+
+# ---------------------------------------------------------------------------
+# Parametrised test
+# ---------------------------------------------------------------------------
+
+
+# Pre-compute readable identifiers for each case to avoid lambda typing issues.
+@pytest.mark.parametrize(
+    "case",
+    [pytest.param(c, id=c.id()) for c in _get_test_cases()],
+)
+def test_worker_plan(case: PlanTestCase) -> None:
+    """Exercise the plan() function across declarative scenarios."""
+
+    print(f"----- case: {case.description}")
+
+    # Rebuild the cases fresh so that state mutated by one parametrised run
+    # cannot leak into another.
+    test_cases = {c.description: c for c in _get_test_cases()}
+    case = test_cases[case.description]
+
+    node_id = NODE_A
+
+    logger = logging.getLogger("test_worker_plan")
+    shard_downloader = NoopShardDownloader()
+    worker = Worker(node_id=node_id, shard_downloader=shard_downloader, worker_events=None, global_events=None, logger=logger)
+
+    runner_config: InProcessRunner
+    for runner_config in case.in_process_runners:
+
+        if len(case.state.instances) == 1:
+            instance_id = next(iter(case.state.instances))
+
+            shard_assignments = case.state.instances[instance_id].shard_assignments
+            shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
+
+            # Only add this runner if it belongs to our node
+            runner_node = None
+            for node, runner in shard_assignments.node_to_runner.items():
+                if runner == runner_config.runner_id:
+                    runner_node = node
+                    break
+
+            if runner_node != node_id:
+                # This runner belongs to a different node, skip it
+                continue
+
+        elif len(case.state.instances) == 0:
+            shard_metadata = PipelineShardMetadata(
+                device_rank=runner_config.device_rank,
+                world_size=1,
+                model_meta=make_model_meta(runner_config.model_id),
+                start_layer=0,
+                end_layer=1,
+                n_layers=1,
+            )
+        else:
+            raise Exception('test_worker_plan not currently designed to have more than 1 instance.')
+
+
+        assigned_runner = AssignedRunner(
+            runner_id=runner_config.runner_id,
+            instance_id=runner_config.instance_id,
+            shard_metadata=shard_metadata,
+            hosts=[],
+            status=runner_config.status,
+            runner=None,
+        )
+        worker.assigned_runners[runner_config.runner_id] = assigned_runner
+
+    op = plan(worker.assigned_runners,
+              NODE_A,
+              case.state.instances,
+              case.state.runners,
+              case.state.tasks,
+              )
+    assert op == case.expected_op
diff --git a/worker/tests/test_plan/test_worker_plan_utils.py b/worker/tests/test_plan/test_worker_plan_utils.py
new file mode 100644
index 00000000..49283013
--- /dev/null
+++ b/worker/tests/test_plan/test_worker_plan_utils.py
@@ -0,0 +1,272 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, NotRequired, Optional, TypedDict
+
+from typing_extensions import Literal
+
+from shared.models.model_cards import MODEL_CARDS, ModelCard
+from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
+from shared.types.common import CommandId, NodeId
+from shared.types.models import ModelId, ModelMetadata
+from shared.types.state import State
+from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
+from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
+from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
+from shared.types.worker.instances import Instance, InstanceStatus
+from shared.types.worker.ops import RunnerOp
+from shared.types.worker.runners import (
+    DownloadingRunnerStatus,
+    RunnerStatus,
+    RunningRunnerStatus,
+    ShardAssignments,
+)
+from shared.types.worker.shards import PipelineShardMetadata
+from worker.tests.constants import COMMAND_1_ID, INSTANCE_1_ID, MODEL_A_ID
+
+
+class RunnerSpecDict(TypedDict):
+    """Type definition for runner specification dictionaries."""
+    runner_id: RunnerId
+    node_id: NodeId
+    device_rank: int
+    status: RunnerStatus
+    downloaded: NotRequired[bool]  # defaults to True if not provided
+
+
+class MessageDict(TypedDict):
+    """Type definition for message dictionaries."""
+    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
+    content: NotRequired[str | None]
+    name: NotRequired[str | None]
+    tool_calls: NotRequired[list[dict[str, str]] | None]
+    tool_call_id: NotRequired[str | None]
+    function_call: NotRequired[dict[str, str] | None]
+
+
+class TaskSpecDict(TypedDict):
+    """Type definition for task specification dictionaries."""
+    task_id: TaskId
+    instance_id: NotRequired[InstanceId]  # defaults to function parameter if not provided
+    command_id: NotRequired[CommandId]  # defaults to COMMAND_1_ID if not provided
+    status: NotRequired[TaskStatus]  # defaults to TaskStatus.PENDING if not provided
+    model: NotRequired[str]  # defaults to model_id if not provided
+    messages: NotRequired[list[MessageDict]]  # defaults to [{'role': 'user', 'content': 'Hello, world!'}] if not provided
+
+
+@dataclass(slots=True, frozen=True)
+class InProcessRunner:
+    """Minimal description of a runner's in-process state."""
+
+    runner_id: RunnerId
+    instance_id: InstanceId
+    model_id: ModelId
+    status: RunnerStatus
+    downloaded: bool
+    device_rank: int = 0
+
+
+@dataclass(slots=True, frozen=True)
+class PlanTestCase:
+    """Table-driven description of an entire planning scenario."""
+
+    description: str
+    state: State
+    in_process_runners: List[InProcessRunner]
+    expected_op: Optional[RunnerOp]
+
+    def id(self) -> str:  # noqa: D401
+        return self.description.replace(" ", "_")
+
+
+def make_shard_metadata(device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID) -> PipelineShardMetadata:
+    """Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
+    total_layers = world_size  # For simplicity in tests, total_layers = world_size
+
+    if world_size == 1:
+        start_layer = 0
+        end_layer = 1
+        n_layers = 1
+    else:
+        # For multi-device setup, each device gets one layer
+        start_layer = device_rank
+        end_layer = device_rank + 1
+        n_layers = total_layers
+
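+    # With total_layers == world_size, every rank owns exactly one layer, so the
+    # [start_layer, end_layer) slice follows directly from device_rank.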
+    return PipelineShardMetadata(
+        device_rank=device_rank,
+        world_size=world_size,
+        model_meta=make_model_meta(model_id),
+        start_layer=start_layer,
+        end_layer=end_layer,
+        n_layers=n_layers,
+    )
+
+
+def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
+    """Factory for a *Downloading* status with placeholder progress."""
+    return DownloadingRunnerStatus(
+        download_progress=DownloadOngoing(
+            node_id=node_id,
+            download_progress=DownloadProgressData(total_bytes=1, downloaded_bytes=0),
+        )
+    )
+
+def make_model_meta(
+    model_id: str
+) -> ModelMetadata:
+    model_card: ModelCard | None = None
+    for card in MODEL_CARDS.values():
+        if card.model_id == model_id:
+            model_card = card
+            break
+
+    if model_card is None:
+        raise Exception(f'Unknown model_id passed: {model_id}')
+
+    return ModelMetadata(
+        model_id=model_id,
+        pretty_name=model_card.model_id,
+        storage_size_kilobytes=10**6,
+        n_layers=16,
+    )
+
+    ## Alternatively, if we are ok for this method to be async:
+    # await _get_model_meta(model_id)
+
+
+
+def make_instance(
+    instance_id: InstanceId,
+    runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
+    model_id: ModelId = MODEL_A_ID,
+    instance_status: InstanceStatus = InstanceStatus.ACTIVE,
+) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, NodeStatus]]:
+    """Creates an instance with one or more runners."""
+    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
+    node_to_runner: dict[NodeId, RunnerId] = {}
+    world_size = len(runner_specs)
+
+    for runner_id, node_id, device_rank, _ in runner_specs:
+        shard_metadata = make_shard_metadata(
+            device_rank,
+            world_size,
+            model_id
+        )
+        runner_to_shard[runner_id] = shard_metadata
+        node_to_runner[node_id] = runner_id
+
+    shard_assignments = ShardAssignments(
+        model_id=model_id,
+        runner_to_shard=runner_to_shard,
+        node_to_runner=node_to_runner,
+    )
+    instance = Instance(
+        instance_id=instance_id,
+        instance_type=instance_status,
+        shard_assignments=shard_assignments,
+        hosts=[],
+    )
+
+    # A node is Running while its runner is mid-inference; otherwise it is Idle.
+    node_statuses: dict[NodeId, NodeStatus] = {}
+    for _runner_id, node_id, _, status in runner_specs:
+        if isinstance(status, RunningRunnerStatus):
+            node_statuses[node_id] = NodeStatus.Running
+        else:
+            node_statuses[node_id] = NodeStatus.Idle
+    runner_statuses = {runner_id: status for runner_id, _, _, status in runner_specs}
+
+    return instance, runner_statuses, node_statuses
+
+def make_state(
+    runner_specs_per_instance: dict[InstanceId, list[tuple[RunnerId, NodeId, int, RunnerStatus]]],
+    tasks: dict[TaskId, ChatCompletionTask] | None = None,
+    model_id: ModelId = MODEL_A_ID,
+    instance_status: InstanceStatus = InstanceStatus.ACTIVE,
+) -> State:
+    """Builds a full State from runner specs per instance, tasks, and defaults."""
+    if tasks is None:
+        tasks = {}
+    instances: dict[InstanceId, Instance] = {}
+    all_runner_statuses: dict[RunnerId, RunnerStatus] = {}
+    all_node_statuses: dict[NodeId, NodeStatus] = {}
+
+    for inst_id, specs in runner_specs_per_instance.items():
+        # Build per-instance data using make_instance
+        instance, runner_statuses, node_statuses = make_instance(
+            instance_id=inst_id,
+            runner_specs=specs,
+            model_id=model_id,
+            instance_status=instance_status,
+        )
+        instances[inst_id] = instance
+        all_runner_statuses.update(runner_statuses)
+        all_node_statuses.update(node_statuses)
+
+    return State(
+        node_status=all_node_statuses,
+
instances=instances, + runners=all_runner_statuses, + tasks=tasks, + ) + +def make_test_case( + description: str, + runner_specs: list[RunnerSpecDict], + tasks: list[TaskSpecDict] | None = None, + expected_op: Optional[RunnerOp] = None, + instance_id: InstanceId = INSTANCE_1_ID, + instance_status: InstanceStatus = InstanceStatus.ACTIVE, + model_id: ModelId = MODEL_A_ID, + command_id: CommandId = COMMAND_1_ID, # Default for tasks +) -> PlanTestCase: + """Builds a PlanTestCase from high-level specs.""" + if tasks is None: + tasks = [] + # Convert runner_specs to tuple format for make_instance + specs_tuple = [ + (r['runner_id'], r['node_id'], r['device_rank'], r['status']) + for r in runner_specs + ] + + # Build state using make_state (wrap single instance) + state_tasks: dict[TaskId, ChatCompletionTask] = {} + for t in tasks: + task = ChatCompletionTask( + instance_id=instance_id, + task_id=t['task_id'], + command_id=t.get('command_id', command_id), + task_type=TaskType.CHAT_COMPLETION, + task_status=t.get('status', TaskStatus.PENDING), + task_params=ChatCompletionTaskParams( + model=t.get('model', str(model_id)), + messages=[ChatCompletionMessage(**m) for m in t.get('messages', [{'role': 'user', 'content': 'Hello, world!'}])], + ), + ) + state_tasks[t['task_id']] = task + + state = make_state( + runner_specs_per_instance={instance_id: specs_tuple}, + tasks=state_tasks, + model_id=model_id, + instance_status=instance_status, + ) + + # Build in_process_runners with downloaded (default True if missing) + in_process_runners = [ + InProcessRunner( + runner_id=r['runner_id'], + instance_id=instance_id, + model_id=model_id, + status=r['status'], + downloaded=r.get('downloaded', True), + device_rank=r['device_rank'], + ) for r in runner_specs + ] + + return PlanTestCase( + description=description, + state=state, + in_process_runners=in_process_runners, + expected_op=expected_op, + ) \ No newline at end of file diff --git a/worker/tests/test_runner_connection.py b/worker/tests/test_runner_connection.py index c988224b..17ddfe79 100644 --- a/worker/tests/test_runner_connection.py +++ b/worker/tests/test_runner_connection.py @@ -9,13 +9,13 @@ from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager from shared.types.common import Host, NodeId from shared.types.events import InstanceCreated, InstanceDeleted from shared.types.models import ModelId -from shared.types.tasks import Task from shared.types.worker.common import InstanceId, RunnerId from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments from shared.types.worker.runners import FailedRunnerStatus from shared.types.worker.shards import PipelineShardMetadata from worker.download.shard_downloader import NoopShardDownloader -from worker.main import Worker +from worker.main import run +from worker.worker import Worker MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") @@ -42,7 +42,6 @@ async def check_runner_connection( logger: Logger, pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, str], Task], ) -> bool: # Track all tasks and workers for cleanup tasks: list[asyncio.Task[None]] = [] @@ -64,7 +63,7 @@ async def check_runner_connection( global_events=global_events, ) workers.append(worker1) - task1 = asyncio.create_task(worker1.run()) + task1 = asyncio.create_task(run(worker1)) tasks.append(task1) worker2 = Worker( @@ 
-75,7 +74,7 @@ async def check_runner_connection( global_events=global_events, ) workers.append(worker2) - task2 = asyncio.create_task(worker2.run()) + task2 = asyncio.create_task(run(worker2)) tasks.append(task2) model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') @@ -151,39 +150,41 @@ async def check_runner_connection( # Check Running status -def test_runner_connection_stress( - logger: Logger, - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, str], Task], -) -> None: - total_runs = 100 - successes = 0 +# # not now. + +# def test_runner_connection_stress( +# logger: Logger, +# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], +# hosts: Callable[[int], list[Host]], +# chat_completion_task: Callable[[InstanceId, str], Task], +# ) -> None: +# total_runs = 100 +# successes = 0 - for _ in range(total_runs): - # Create a fresh event loop for each iteration - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) +# for _ in range(total_runs): +# # Create a fresh event loop for each iteration +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(check_runner_connection( - logger=logger, - pipeline_shard_meta=pipeline_shard_meta, - hosts=hosts, - chat_completion_task=chat_completion_task, - )) - if result: - successes += 1 - finally: - # Cancel all running tasks - pending = asyncio.all_tasks(loop) - for task in pending: - task.cancel() +# try: +# result = loop.run_until_complete(check_runner_connection( +# logger=logger, +# pipeline_shard_meta=pipeline_shard_meta, +# hosts=hosts, +# chat_completion_task=chat_completion_task, +# )) +# if result: +# successes += 1 +# finally: +# # Cancel all running tasks +# pending = asyncio.all_tasks(loop) +# for task in pending: +# task.cancel() - # Run the event loop briefly to allow cancellation to complete - loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) +# # Run the event loop briefly to allow cancellation to complete +# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - # Close the event loop - loop.close() +# # Close the event loop +# loop.close() - print(f"Runner connection successes: {successes} / {total_runs}") +# print(f"Runner connection successes: {successes} / {total_runs}") diff --git a/worker/tests/test_serdes.py b/worker/tests/test_serdes.py index 67782e4f..29484833 100644 --- a/worker/tests/test_serdes.py +++ b/worker/tests/test_serdes.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Callable, TypeVar from pydantic import BaseModel, TypeAdapter @@ -28,7 +27,6 @@ def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]): def test_supervisor_setup_message_serdes( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], - tmp_path: Path, ): setup_message = SetupMessage( model_shard_meta=pipeline_shard_meta(1, 0), diff --git a/worker/tests/test_spinup_timeout.py b/worker/tests/test_spinup_timeout.py index f8966d8e..c01363fa 100644 --- a/worker/tests/test_spinup_timeout.py +++ b/worker/tests/test_spinup_timeout.py @@ -10,13 +10,13 @@ from shared.types.events import ( ) from shared.types.events._events import RunnerStatusUpdated from shared.types.tasks import Task, TaskId -from shared.types.worker.common import RunnerId from shared.types.worker.instances import Instance, InstanceId from shared.types.worker.ops import ( RunnerUpOp, ) from 
shared.types.worker.runners import FailedRunnerStatus from worker.main import Worker +from worker.tests.constants import RUNNER_1_ID # To enable this test, run pytest with: ENABLE_SPINUP_TIMEOUT_TEST=true pytest @@ -26,13 +26,13 @@ from worker.main import Worker ) @pytest.mark.asyncio async def test_runner_up_op_timeout( - worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], + worker_with_assigned_runner: tuple[Worker, Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], monkeypatch: pytest.MonkeyPatch ): - worker, runner_id, _ = worker_with_assigned_runner + worker, _ = worker_with_assigned_runner - runner_up_op = RunnerUpOp(runner_id=runner_id) + runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) # _execute_runner_up_op should throw a TimeoutError with a short timeout events: list[Event] = [] diff --git a/worker/tests/test_supervisor.py b/worker/tests/test_supervisor/test_supervisor.py similarity index 98% rename from worker/tests/test_supervisor.py rename to worker/tests/test_supervisor/test_supervisor.py index 915c7393..59ddcf91 100644 --- a/worker/tests/test_supervisor.py +++ b/worker/tests/test_supervisor/test_supervisor.py @@ -1,6 +1,5 @@ import asyncio from logging import Logger -from pathlib import Path from typing import Callable import pytest @@ -30,7 +29,6 @@ async def test_supervisor_single_node_response( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - tmp_path: Path, logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" @@ -70,7 +68,6 @@ async def test_supervisor_two_node_response( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - tmp_path: Path, logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" @@ -133,7 +130,6 @@ async def test_supervisor_early_stopping( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], - tmp_path: Path, logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" @@ -189,7 +185,6 @@ async def test_supervisor_handles_terminated_runner( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], logger: Logger, - tmp_path: Path, ): """Test that the supervisor handles a terminated runner""" model_shard_meta = pipeline_shard_meta(1, 0) @@ -214,7 +209,6 @@ async def test_supervisor_handles_terminated_runner( async def test_supervisor_handles_killed_runner( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], - tmp_path: Path, logger: Logger, ): """Test that the supervisor handles a killed runner""" diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py deleted file mode 100644 index bc145db7..00000000 --- a/worker/tests/test_worker_handlers.py +++ /dev/null @@ -1,237 +0,0 @@ -## Tests for worker state handlers - -from pathlib import Path -from typing import Callable - -import pytest - -from shared.types.common import NodeId -from shared.types.events import ( - ChunkGenerated, - Event, - RunnerDeleted, - RunnerStatusUpdated, - TaskFailed, - TaskStateUpdated, -) -from shared.types.events.chunks import TokenChunk -from shared.types.tasks import Task, TaskId, TaskStatus -from 
shared.types.worker.common import RunnerId -from shared.types.worker.instances import Instance, InstanceId -from shared.types.worker.ops import ( - AssignRunnerOp, - DownloadOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerUpOp, - UnassignRunnerOp, -) -from shared.types.worker.runners import ( - AssignedRunnerStatus, - FailedRunnerStatus, - LoadedRunnerStatus, - ReadyRunnerStatus, - RunningRunnerStatus, -) -from worker.main import Worker - - -@pytest.fixture -def user_message(): - """Override the default message to ask about France's capital""" - return "What, according to Douglas Adams, is the meaning of life, the universe and everything?" - -@pytest.mark.asyncio -async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], tmp_path: Path): - runner_id = RunnerId() - instance_obj: Instance = instance(InstanceId(), worker.node_id, runner_id) - - assign_op = AssignRunnerOp( - runner_id=runner_id, - shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.hosts, - instance_id=instance_obj.instance_id, - ) - - events: list[Event] = [] - - async for event in worker._execute_op(assign_op): # type: ignore[misc] - events.append(event) - - # We should have a status update saying 'starting'. - assert len(events) == 1 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, AssignedRunnerStatus) - - # And the runner should be assigned - assert runner_id in worker.assigned_runners - assert isinstance(worker.assigned_runners[runner_id].status, AssignedRunnerStatus) - -@pytest.mark.asyncio -async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path): - worker, runner_id, _ = worker_with_assigned_runner - - unassign_op = UnassignRunnerOp( - runner_id=runner_id - ) - - events: list[Event] = [] - - async for event in worker._execute_op(unassign_op): # type: ignore[misc] - events.append(event) - - # We should have no assigned runners and no events were emitted - assert len(worker.assigned_runners) == 0 - assert len(events) == 1 - assert isinstance(events[0], RunnerDeleted) - -@pytest.mark.asyncio -async def test_runner_up_op( - worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - tmp_path: Path - ): - worker, runner_id, _ = worker_with_assigned_runner - - runner_up_op = RunnerUpOp(runner_id=runner_id) - - events: list[Event] = [] - async for event in worker._execute_op(runner_up_op): # type: ignore[misc] - events.append(event) - - assert len(events) == 1 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, LoadedRunnerStatus) - - # Is the runner actually running? - supervisor = next(iter(worker.assigned_runners.values())).runner - assert supervisor is not None - assert supervisor.healthy - - full_response = '' - - async for chunk in supervisor.stream_response(task=chat_completion_task(InstanceId(), TaskId())): - if isinstance(chunk, TokenChunk): - full_response += chunk.text - - assert "42" in full_response.lower(), ( - f"Expected '42' in response, but got: {full_response}" - ) - - runner = worker.assigned_runners[runner_id].runner - assert runner is not None - await runner.astop() # Neat cleanup. 
- -@pytest.mark.asyncio -async def test_runner_down_op(worker_with_running_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path): - worker, runner_id, _ = worker_with_running_runner - - runner_down_op = RunnerDownOp(runner_id=runner_id) - events: list[Event] = [] - async for event in worker._execute_op(runner_down_op): # type: ignore[misc] - events.append(event) - - assert len(events) == 1 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, ReadyRunnerStatus) - -@pytest.mark.asyncio -async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path): - worker, runner_id, instance_obj = worker_with_assigned_runner - - print(f'{worker.assigned_runners=}') - - download_op = DownloadOp( - instance_id=instance_obj.instance_id, - runner_id=runner_id, - shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.hosts, - ) - - events: list[Event] = [] - - async for event in worker._execute_op(download_op): # type: ignore[misc] - events.append(event) - - # Should give download status and then a final download status with DownloadCompleted - print(events) - -@pytest.mark.asyncio -async def test_execute_task_op( - worker_with_running_runner: tuple[Worker, RunnerId, Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path): - worker, runner_id, _ = worker_with_running_runner - - execute_task_op = ExecuteTaskOp( - runner_id=runner_id, - task=chat_completion_task(InstanceId(), TaskId()) - ) - - events: list[Event] = [] - async for event in worker._execute_op(execute_task_op): # type: ignore[misc] - events.append(event) - - assert len(events) > 20 - - print(f'{events=}') - - - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, RunningRunnerStatus) - - assert isinstance(events[1], TaskStateUpdated) - assert events[1].task_status == TaskStatus.RUNNING # It tried to start. - - assert isinstance(events[-2], TaskStateUpdated) - assert events[-2].task_status == TaskStatus.COMPLETE # It tried to start. - - assert isinstance(events[-1], RunnerStatusUpdated) - assert isinstance(events[-1].runner_status, LoadedRunnerStatus) # It should not have failed. - - gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)] - text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)] - assert len(text_chunks) == len(events) - 4 - - output_text = ''.join([x.text for x in text_chunks]) - assert '42' in output_text - - runner = worker.assigned_runners[runner_id].runner - assert runner is not None - await runner.astop() # Neat cleanup. - -@pytest.mark.asyncio -async def test_execute_task_fails( - worker_with_running_runner: tuple[Worker, RunnerId, Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path): - worker, runner_id, _ = worker_with_running_runner - - task = chat_completion_task(InstanceId(), TaskId()) - messages = task.task_params.messages - messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL' - - execute_task_op = ExecuteTaskOp( - runner_id=runner_id, - task=task - ) - - events: list[Event] = [] - async for event in worker._execute_op(execute_task_op): # type: ignore[misc] - events.append(event) - - assert len(events) == 5 - - print(events) - - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start. 
- - assert isinstance(events[1], TaskStateUpdated) - assert events[1].task_status == TaskStatus.RUNNING # It tried to start. - - assert isinstance(events[2], TaskStateUpdated) - assert events[2].task_status == TaskStatus.FAILED # Task marked as failed. - - assert isinstance(events[3], TaskFailed) - - assert isinstance(events[4], RunnerStatusUpdated) - assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed. \ No newline at end of file diff --git a/worker/tests/test_worker_plan.py b/worker/tests/test_worker_plan.py deleted file mode 100644 index 040d47ee..00000000 --- a/worker/tests/test_worker_plan.py +++ /dev/null @@ -1,913 +0,0 @@ -from __future__ import annotations - -import logging -import tempfile -from pathlib import Path - -import pytest - -from shared.types.api import ChatCompletionMessage -from shared.types.state import State -from shared.types.tasks import ( - ChatCompletionTask, - ChatCompletionTaskParams, - TaskStatus, - TaskType, -) -from shared.types.worker.common import NodeStatus -from shared.types.worker.downloads import DownloadPending -from shared.types.worker.instances import Instance, InstanceStatus -from shared.types.worker.ops import ( - AssignRunnerOp, - DownloadOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerUpOp, - UnassignRunnerOp, -) -from shared.types.worker.runners import ( - AssignedRunnerStatus, - DownloadingRunnerStatus, - FailedRunnerStatus, - LoadedRunnerStatus, - ReadyRunnerStatus, - RunningRunnerStatus, - ShardAssignments, -) -from shared.types.worker.shards import PipelineShardMetadata -from worker.download.download_utils import build_model_path -from worker.download.shard_downloader import NoopShardDownloader -from worker.main import AssignedRunner, Worker - -from .test_worker_plan_utils import ( - COMMAND_1_ID, - INSTANCE_1_ID, - MODEL_A_ID, - NODE_A, - NODE_B, - RUNNER_1_ID, - RUNNER_2_ID, - TASK_1_ID, - InProcessRunner, - PlanTestCase, - make_downloading_status, - make_model_meta, - make_shard_metadata, -) - -""" -The idea with these tests is to define declaratively the input and expected output of the worker.plan function. - -We initialize a Worker with InProcessRunners. We then construct a State which gets passed to Worker.plan. -We then check what operation is returned by Worker.plan. -""" - -def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]: - # The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation. 
- (tmp_path / f"model_for_runner_{RUNNER_1_ID}").mkdir(exist_ok=True, parents=True) - model_a_meta = make_model_meta(MODEL_A_ID) - return [ - PlanTestCase( - description="no runners -> no-op", - in_process_runners=[], - state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}), - expected_op=None, - ), - - # I don't think this should ever happen, as if it's currently downloading then the worker loop will be blocked - # Potentially useful for future compatibility when worker becomes non-blocking - PlanTestCase( - description="runner state assigned, runner is assigned and downloading -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=make_downloading_status(NODE_A), - downloaded=False, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.INACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: make_downloading_status(NODE_A)}, - ), - expected_op=None, - ), - - PlanTestCase( - description="runner state downloading, runner is downloading -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=make_downloading_status(NODE_A), - downloaded=False, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.INACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: make_downloading_status(NODE_A)}, - ), - expected_op=None, - ), - - PlanTestCase( - description="ready runner, model present -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.INACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: ReadyRunnerStatus()}, - ), - expected_op=None, - ), - - PlanTestCase( - description="runner assigned and not in state -> AssignRunnerOp", - in_process_runners=[], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, # Either active or inactive should yield the same. 
- instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: AssignedRunnerStatus()}, - ), - expected_op=AssignRunnerOp( - instance_id=INSTANCE_1_ID, - runner_id=RUNNER_1_ID, - shard_metadata=PipelineShardMetadata( - device_rank=0, - world_size=1, - model_meta=model_a_meta, - start_layer=0, - end_layer=1, - n_layers=1, - ), - hosts=[] - ), - ), - - PlanTestCase( - description="runner assigned but no longer in state -> UnassignRunnerOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=AssignedRunnerStatus(), - downloaded=False, - ) - ], - state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}), - expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID), - ), - - PlanTestCase( - description="runner state assigned, runner is assigned, not downloaded -> expect DownloadOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=AssignedRunnerStatus(), - downloaded=False, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: AssignedRunnerStatus()}, - ), - expected_op=DownloadOp( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - shard_metadata=PipelineShardMetadata( - device_rank=0, - world_size=1, - model_meta=model_a_meta, - start_layer=0, - end_layer=1, - n_layers=1, - ), - hosts=[], - ), - ), - - PlanTestCase( - description="ready runner (and state up) -> expect RunnerUpOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: ReadyRunnerStatus()}, - tasks={}, - ), - expected_op=RunnerUpOp(runner_id=RUNNER_1_ID), - ), - - PlanTestCase( - description="1 ready, 1 downloading (and state up) -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=DownloadingRunnerStatus( - download_progress=DownloadPending(node_id=NODE_A) - ), - downloaded=False, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - 
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))}, - tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)}, - ), - expected_op=None - ), - - PlanTestCase( - description="2 ready runners (and state up) -> expect RunnerUpOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()}, - tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)}, - ), - expected_op=RunnerUpOp(runner_id=RUNNER_1_ID) - ), - - PlanTestCase( - description="loaded runner (and state down) -> expect RunnerDownOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.INACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus()}, - tasks={}, - ), - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - - PlanTestCase( - description="failed runner (and state down) -> expect RunnerDownOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=FailedRunnerStatus(), - downloaded=True, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.INACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: FailedRunnerStatus()}, - tasks={}, - ), - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), 
- - PlanTestCase( - description="loaded runner, model present, task pending -> expect ExecuteTaskOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - ) - ], - state=State( - node_status={NODE_A: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1) - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus()}, - tasks={ - TASK_1_ID: ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ - ChatCompletionMessage( - role="user", - content="Hello, world!" - ) - ] - ), - instance_id=INSTANCE_1_ID - ) - }, - ), - expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ChatCompletionMessage(role="user", content="Hello, world!")] - ), - )), - ), - - PlanTestCase( - # We should only run rank 0 once all other ranks are running. - description="two loaded runners & task, i'm rank 0 -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, - tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)}, - ), - expected_op=None - ), - - PlanTestCase( - description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=1, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - 
shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=1, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=0, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, - tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)}, - ), - expected_op=ExecuteTaskOp( - runner_id=RUNNER_1_ID, - task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ChatCompletionMessage(role="user", content="Hello, world!")], - ), - task_status=TaskStatus.PENDING, - ), - ), - ), - - PlanTestCase( - description="rank 1 loaded, rank 0 ready, i'm rank 0 -> expect ExecuteTaskOp on rank 0", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=RunningRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()}, - tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)}, - ), - expected_op=ExecuteTaskOp( - runner_id=RUNNER_1_ID, - task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ChatCompletionMessage(role="user", content="Hello, world!")], - ), - task_status=TaskStatus.PENDING, - ), - ), - ), - - PlanTestCase( - description="other runner failed -> RunnerDownOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=FailedRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - 
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()}, - ), - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID) - ), - - PlanTestCase( - description="this runner failed (1 node) -> RunnerDownOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=FailedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1), - }, - node_to_runner={NODE_A: RUNNER_1_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: FailedRunnerStatus()}, - ), - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID) - ), - - - PlanTestCase( - description="this runner failed (2 nodes) -> no-op", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=FailedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=LoadedRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()}, - ), - expected_op=None - ), - - PlanTestCase( - description="this node failed, other node spun down -> RunnerDownOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=FailedRunnerStatus(), - downloaded=True, - device_rank=0, - ), - InProcessRunner( - runner_id=RUNNER_2_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=ReadyRunnerStatus(), - downloaded=True, - device_rank=1, - ), - ], - state=State( - node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle}, - instances={ - INSTANCE_1_ID: Instance( - instance_type=InstanceStatus.ACTIVE, - instance_id=INSTANCE_1_ID, - shard_assignments=ShardAssignments( - model_id=MODEL_A_ID, - runner_to_shard={ - RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2), - RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2) - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID} - ), - hosts=[] - ) - }, - runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()}, - ), - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID) - ), - - ] - - -# --------------------------------------------------------------------------- -# Parametrised test -# --------------------------------------------------------------------------- - - -# Pre-compute readable identifiers for each case to avoid lambda typing issues. 
-@pytest.mark.parametrize(
-    "case",
-    # We use a factory to delay test case generation until tmp_path is available.
-    [pytest.param(c, id=c.id()) for c in _get_test_cases(Path(tempfile.TemporaryDirectory().name))],
-)
-def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
-    """Exercise Worker.plan across declarative scenarios."""
-
-    print(f"----- case: {case.description}")
-
-    # Regenerate test cases with the actual tmp_path fixture
-    test_cases = {c.description: c for c in _get_test_cases(tmp_path)}
-    case = test_cases[case.description]
-
-    node_id = NODE_A
-
-    logger = logging.getLogger("test_worker_plan")
-    shard_downloader = NoopShardDownloader()
-    worker = Worker(node_id=node_id, shard_downloader=shard_downloader, worker_events=None, global_events=None, logger=logger)
-
-    path_downloaded_map: dict[str, bool] = {}
-
-    runner_config: InProcessRunner
-    for runner_config in case.in_process_runners:
-
-        model_path = tmp_path / f"model_for_runner_{runner_config.runner_id}"
-        model_path.mkdir(exist_ok=True, parents=True)
-
-        if len(case.state.instances) == 1:
-            instance_id = next(iter(case.state.instances))
-
-            shard_assignments = case.state.instances[instance_id].shard_assignments
-            shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
-
-            # Only add this runner if it belongs to our node
-            runner_node = None
-            for node, runner in shard_assignments.node_to_runner.items():
-                if runner == runner_config.runner_id:
-                    runner_node = node
-                    break
-
-            if runner_node != node_id:
-                # This runner belongs to a different node, skip it
-                continue
-
-        elif len(case.state.instances) == 0:
-            shard_metadata = PipelineShardMetadata(
-                device_rank=runner_config.device_rank,
-                world_size=1,
-                model_meta=make_model_meta(runner_config.model_id),
-                start_layer=0,
-                end_layer=1,
-                n_layers=1,
-            )
-        else:
-            raise Exception('test_worker_plan not currently designed to have more than 1 instance.')
-
-        assigned_runner = AssignedRunner(
-            runner_id=runner_config.runner_id,
-            instance_id=runner_config.instance_id,
-            shard_metadata=shard_metadata,
-            hosts=[],
-            status=runner_config.status,
-            runner=None,
-            is_downloaded=runner_config.downloaded
-        )
-        worker.assigned_runners[runner_config.runner_id] = assigned_runner
-        path_downloaded_map[str(build_model_path(shard_metadata.model_meta.model_id))] = runner_config.downloaded
-
-    op = worker.plan(case.state)
-    assert op == case.expected_op
diff --git a/worker/tests/test_worker_plan_utils.py b/worker/tests/test_worker_plan_utils.py
deleted file mode 100644
index 84d92ab0..00000000
--- a/worker/tests/test_worker_plan_utils.py
+++ /dev/null
@@ -1,195 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Final, List, Optional
-
-from shared.models.model_cards import MODEL_CARDS, ModelCard
-from shared.types.common import CommandId, NodeId
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.state import State
-from shared.types.tasks import TaskId
-from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
-from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
-from shared.types.worker.instances import Instance, InstanceStatus
-from shared.types.worker.ops import RunnerOp
-from shared.types.worker.runners import (
-    AssignedRunnerStatus,
-    DownloadingRunnerStatus,
-    RunnerStatus,
-    ShardAssignments,
-)
-from shared.types.worker.shards import PipelineShardMetadata
-
-NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
-NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
-
-# Define constant IDs for deterministic test cases
-RUNNER_1_ID: Final[RunnerId] = RunnerId("cccccccc-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
-INSTANCE_1_ID: Final[InstanceId] = InstanceId()
-RUNNER_2_ID: Final[RunnerId] = RunnerId("dddddddd-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
-INSTANCE_2_ID: Final[InstanceId] = InstanceId()
-MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
-MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
-TASK_1_ID: Final[TaskId] = TaskId()
-COMMAND_1_ID: Final[CommandId] = CommandId()
-
-@dataclass(slots=True, frozen=True)
-class InProcessRunner:
-    """Minimal description of a runner's in-process state."""
-
-    runner_id: RunnerId
-    instance_id: InstanceId
-    model_id: ModelId
-    status: RunnerStatus
-    downloaded: bool
-    device_rank: int = 0
-
-
-@dataclass(slots=True, frozen=True)
-class PlanTestCase:
-    """Table-driven description of an entire planning scenario."""
-
-    description: str
-    state: State
-    in_process_runners: List[InProcessRunner]
-    expected_op: Optional[RunnerOp]
-
-    def id(self) -> str:  # noqa: D401
-        return self.description.replace(" ", "_")
-
-
-def make_shard_metadata(device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID) -> PipelineShardMetadata:
-    """Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
-    total_layers = world_size  # For simplicity in tests, total_layers = world_size
-
-    if world_size == 1:
-        start_layer = 0
-        end_layer = 1
-        n_layers = 1
-    else:
-        # For multi-device setup, each device gets one layer
-        start_layer = device_rank
-        end_layer = device_rank + 1
-        n_layers = total_layers
-
-    return PipelineShardMetadata(
-        device_rank=device_rank,
-        world_size=world_size,
-        model_meta=make_model_meta(model_id),
-        start_layer=start_layer,
-        end_layer=end_layer,
-        n_layers=n_layers,
-    )
-
-
-def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
-    """Factory for a *Downloading* status with placeholder progress."""
-    return DownloadingRunnerStatus(
-        download_progress=DownloadOngoing(
-            node_id=node_id,
-            download_progress=DownloadProgressData(total_bytes=1, downloaded_bytes=0),
-        )
-    )
-
-def make_model_meta(
-    model_id: str
-) -> ModelMetadata:
-    model_card: ModelCard
-    for card in MODEL_CARDS.values():
-        if card.model_id == model_id:
-            model_card = card
-
-    return ModelMetadata(
-        model_id=model_id,
-        pretty_name=model_card.model_id,
-        storage_size_kilobytes=10**6,
-        n_layers=16,
-    )
-
-    raise Exception(f'Unknown model_id passed: {model_id}')
-
-    ## Alternatively, if we are ok for this method to be async:
-    # await _get_model_meta(model_id)
-
-
-def create_worker_state(
-    *,
-    node_id: NodeId,
-    runner_configs: list[tuple[RunnerId, InstanceId, ModelId]],
-    tmp_path: Path,
-) -> State:
-    """Create a test `State` based on a list of runner configurations."""
-    instances: dict[InstanceId, Instance] = {}
-    for runner_id, instance_id, model_id in runner_configs:
-        model_path = tmp_path / f"model_for_runner_{runner_id}"
-        model_path.mkdir(exist_ok=True, parents=True)
-
-        shard_metadata = PipelineShardMetadata(
-            device_rank=0,
-            world_size=1,
-            model_meta=make_model_meta(model_id),
-            start_layer=0,
-            end_layer=1,
-            n_layers=1,
-        )
-        shard_assignments = ShardAssignments(
-            model_id=model_id,
-            runner_to_shard={runner_id: shard_metadata},
-            node_to_runner={node_id: runner_id},
-        )
-        instance = Instance(
-            instance_id=instance_id,
-            instance_type=InstanceStatus.ACTIVE,
-            shard_assignments=shard_assignments,
-            hosts=[],
-        )
-        instances[instance_id] = instance
-
-    return State(
-        node_status={node_id: NodeStatus.Idle},
-        instances=instances,
-        runners={runner_id: AssignedRunnerStatus() for runner_id, _, _ in runner_configs},
-        tasks={},
-    )
-
-
-def make_instance(
-    instance_id: InstanceId,
-    model_id: ModelId,
-    tmp_path: Path,
-    runner_specs: list[tuple[RunnerId, NodeId, int]],
-) -> Instance:
-    """Creates an instance with one or more runners."""
-    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
-    node_to_runner: dict[NodeId, RunnerId] = {}
-    world_size = len(runner_specs)
-
-    for runner_id, node_id, device_rank in runner_specs:
-        model_path = tmp_path / f"model_for_runner_{runner_id}"
-        model_path.mkdir(exist_ok=True, parents=True)
-
-        shard_metadata = PipelineShardMetadata(
-            device_rank=device_rank,
-            world_size=world_size,
-            model_meta=make_model_meta(model_id),
-            start_layer=0,
-            end_layer=1,
-            n_layers=1,
-        )
-        runner_to_shard[runner_id] = shard_metadata
-        node_to_runner[node_id] = runner_id
-
-    shard_assignments = ShardAssignments(
-        model_id=model_id,
-        runner_to_shard=runner_to_shard,
-        node_to_runner=node_to_runner,
-    )
-    return Instance(
-        instance_id=instance_id,
-        instance_type=InstanceStatus.ACTIVE,
-        shard_assignments=shard_assignments,
-        hosts=[],
-    )
-
-### For worker plan tests
\ No newline at end of file
diff --git a/worker/worker.py b/worker/worker.py
new file mode 100644
index 00000000..5c874c6f
--- /dev/null
+++ b/worker/worker.py
@@ -0,0 +1,415 @@
+import asyncio
+import logging
+import time
+from asyncio import Queue
+from functools import partial
+from typing import AsyncGenerator, Optional
+
+from shared.db.sqlite import AsyncSQLiteEventStorage
+from shared.types.common import NodeId
+from shared.types.events import (
+    ChunkGenerated,
+    Event,
+    InstanceDeleted,
+    RunnerDeleted,
+    RunnerStatusUpdated,
+    TaskFailed,
+    TaskStateUpdated,
+)
+from shared.types.state import State
+from shared.types.tasks import TaskId, TaskStatus
+from shared.types.worker.common import RunnerId
+from shared.types.worker.downloads import (
+    DownloadCompleted,
+    DownloadFailed,
+    DownloadOngoing,
+    DownloadPending,
+    DownloadProgressData,
+)
+from shared.types.worker.ops import (
+    AssignRunnerOp,
+    ExecuteTaskOp,
+    RunnerDownOp,
+    RunnerFailedOp,
+    RunnerOp,
+    RunnerOpType,
+    RunnerUpOp,
+    UnassignRunnerOp,
+)
+from shared.types.worker.runners import (
+    DownloadingRunnerStatus,
+    FailedRunnerStatus,
+    InactiveRunnerStatus,
+    LoadedRunnerStatus,
+    RunningRunnerStatus,
+)
+from shared.types.worker.shards import ShardMetadata
+from worker.common import AssignedRunner
+from worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
+from worker.runner.runner_supervisor import RunnerSupervisor
+
+
+class Worker:
+    def __init__(
+        self,
+        node_id: NodeId,
+        logger: logging.Logger,
+        shard_downloader: ShardDownloader,
+        worker_events: AsyncSQLiteEventStorage | None,
+        global_events: AsyncSQLiteEventStorage | None,
+    ):
+        self.node_id: NodeId = node_id
+        self.state: State = State()
+        self.shard_downloader: ShardDownloader = shard_downloader
+        self.worker_events: AsyncSQLiteEventStorage | None = worker_events  # worker_events is None in some tests.
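+        # global_events can likewise be None in tests (the worker plan tests
+        # construct Worker(..., worker_events=None, global_events=None, ...)).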
+        self.global_events: AsyncSQLiteEventStorage | None = global_events
+        self.logger: logging.Logger = logger
+
+        self.assigned_runners: dict[RunnerId, AssignedRunner] = {}
+        self._task: asyncio.Task[None] | None = None
+
+    ## Op Executors
+
+    async def _execute_assign_op(
+        self, op: AssignRunnerOp
+    ) -> AsyncGenerator[Event, None]:
+        '''
+        A runner has been assigned; we also need to ensure its model is downloaded.
+        This op assigns the runner and moves it from the Downloading state to
+        Inactive (ready to spin up).
+        '''
+        self.assigned_runners[op.runner_id] = AssignedRunner(
+            runner_id=op.runner_id,
+            instance_id=op.instance_id,
+            shard_metadata=op.shard_metadata,
+            hosts=op.hosts,
+            status=DownloadingRunnerStatus(
+                download_progress=DownloadPending(
+                    node_id=self.node_id
+                )
+            ),
+            runner=None,
+        )
+
+        assigned_runner = self.assigned_runners[op.runner_id]
+        initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)
+
+        if initial_progress.status == "complete":
+            assigned_runner.status = DownloadingRunnerStatus(
+                download_progress=DownloadCompleted(
+                    node_id=self.node_id
+                )
+            )
+            yield assigned_runner.status_update_event()
+
+            assigned_runner.status = InactiveRunnerStatus()
+            yield assigned_runner.status_update_event()
+
+            return
+        else:
+            assigned_runner.status = DownloadingRunnerStatus(
+                download_progress=DownloadOngoing(
+                    node_id=self.node_id,
+                    download_progress=DownloadProgressData(
+                        total_bytes=initial_progress.total_bytes,
+                        downloaded_bytes=initial_progress.downloaded_bytes
+                    )
+                )
+            )
+            yield assigned_runner.status_update_event()
+
+        # Download it!
+        # TODO: we probably want download progress as part of a callback that gets passed to the downloader.
+        download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()
+
+        def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
+            download_progress_queue.put_nowait(progress)
+
+        self.shard_downloader.on_progress(download_progress_callback)
+
+        asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))
+
+        # TODO: Dynamic timeout; time out when no progress update has been received for a while.
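+        # The loop below drains the progress queue: it re-emits DownloadOngoing at
+        # most about once per second, flips the runner to Inactive once the
+        # download reports "complete", and reports DownloadFailed if the overall
+        # deadline expires. The deadline is only re-checked when a progress
+        # message actually arrives (hence the TODO above).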
+        timeout_secs = 10 * 60
+        # Wall-clock deadline: time.monotonic() keeps advancing while we are
+        # blocked on the queue, whereas a CPU-time clock would not.
+        start_time = time.monotonic()
+        last_yield_progress = start_time
+        while time.monotonic() - start_time < timeout_secs:
+            progress: RepoDownloadProgress = await download_progress_queue.get()
+            if progress.status == "complete":
+                assigned_runner.status = DownloadingRunnerStatus(
+                    download_progress=DownloadCompleted(
+                        node_id=self.node_id,
+                    )
+                )
+                yield assigned_runner.status_update_event()
+
+                assigned_runner.status = InactiveRunnerStatus()
+                yield assigned_runner.status_update_event()
+
+                break
+            elif progress.status == "in_progress":
+                if time.monotonic() - last_yield_progress > 1:
+                    assigned_runner.status = DownloadingRunnerStatus(
+                        download_progress=DownloadOngoing(
+                            node_id=self.node_id,
+                            download_progress=DownloadProgressData(
+                                total_bytes=progress.total_bytes,
+                                downloaded_bytes=progress.downloaded_bytes,
+                            )
+                        )
+                    )
+                    yield assigned_runner.status_update_event()
+
+                    last_yield_progress = time.monotonic()
+        else:
+            # while-else: we only get here if the loop exhausted its deadline
+            # without hitting the `break` above.
+            assigned_runner.status = DownloadingRunnerStatus(
+                download_progress=DownloadFailed(
+                    node_id=self.node_id,
+                    error_message=f"Timeout downloading model: {op.shard_metadata.model_meta.model_id}"
+                )
+            )
+            yield assigned_runner.status_update_event()
+
+    async def _execute_unassign_op(
+        self, op: UnassignRunnerOp
+    ) -> AsyncGenerator[Event, None]:
+        if op.runner_id not in self.assigned_runners:
+            return
+
+        # We can try to do a graceful shutdown of the runner.
+        runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner
+        if runner is not None:
+            await runner.astop()
+
+        # This is all we really need:
+        del self.assigned_runners[op.runner_id]
+        yield RunnerDeleted(runner_id=op.runner_id)
+
+    async def _execute_runner_up_op(
+        self, op: RunnerUpOp, initialize_timeout: Optional[float] = None
+    ) -> AsyncGenerator[Event, None]:
+        assigned_runner = self.assigned_runners[op.runner_id]
+
+        # TODO: This should be dynamic, based on the size of the model.
+        if not initialize_timeout:
+            gigabytes_per_second = 10
+            kilobytes_per_second = gigabytes_per_second * 1024 * 1024
+
+            shard = assigned_runner.shard_metadata
+            weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
+
+            initialize_timeout = weights_size_kb / kilobytes_per_second + 120.0  # Add a constant 120.0 to ensure the connection can be made as well
+
+        self.logger.info(f"initialize_timeout: {initialize_timeout}")
+
+        try:
+            assigned_runner.runner = await asyncio.wait_for(
+                RunnerSupervisor.create(
+                    model_shard_meta=assigned_runner.shard_metadata,
+                    hosts=assigned_runner.hosts,
+                    logger=self.logger,
+                ),
+                timeout=initialize_timeout,
+            )
+        except TimeoutError as e:
+            import traceback
+
+            tb = traceback.format_exc()
+            e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
+            async for event in self._fail_runner(e=e, runner_id=op.runner_id):
+                yield event
+            return
+
+        if assigned_runner.runner.healthy:
+            assigned_runner.status = LoadedRunnerStatus()
+        else:
+            assigned_runner.status = FailedRunnerStatus()
+        yield self.assigned_runners[op.runner_id].status_update_event()
+
+    async def _execute_runner_down_op(
+        self, op: RunnerDownOp
+    ) -> AsyncGenerator[Event, None]:
+        assigned_runner = self.assigned_runners[op.runner_id]
+
+        if isinstance(assigned_runner.runner, RunnerSupervisor):
+            await assigned_runner.runner.astop()
+
+        assigned_runner.runner = None
+
+        assigned_runner.status = InactiveRunnerStatus()
+        yield assigned_runner.status_update_event()
+        return
+
+    async def _execute_runner_failed_op(
+        self, op: RunnerFailedOp
+    ) -> AsyncGenerator[Event, None]:
+        '''
+        We detected that this runner has failed, so we put it into the 'failed'
+        state now, triggering the rest of the instance to spin down.
+        '''
+        assigned_runner = self.assigned_runners[op.runner_id]
+
+        assigned_runner.status = FailedRunnerStatus()
+        yield self.assigned_runners[op.runner_id].status_update_event()
+
+
+    async def _execute_task_op(
+        self, op: ExecuteTaskOp
+    ) -> AsyncGenerator[Event, None]:
+        '''
+        This is the entry point for a chat completion starting.
+        While there is only one execute function, it gets called in different ways for runner 0 and runners [1, 2, 3, ...].
+        Runners [1, 2, 3, ...] run this method when a task is in the 'pending' state.
+        Runner 0 runs this method when a task is in the 'running' state.
+        TODO: How do we ensure that the other n-1 runners have started their execution before allowing the 0'th runner to start?
+        This is still unclear.
+        '''
+        assigned_runner = self.assigned_runners[op.runner_id]
+
+        async def inner_execute(queue: asyncio.Queue[Event]) -> None:
+            async def running_callback(queue: asyncio.Queue[Event]) -> None:
+                # Called when the MLX process has been kicked off
+                assigned_runner.status = RunningRunnerStatus()
+                await queue.put(assigned_runner.status_update_event())
+
+                if assigned_runner.shard_metadata.device_rank == 0:
+                    await queue.put(TaskStateUpdated(
+                        task_id=op.task.task_id,
+                        task_status=TaskStatus.RUNNING,
+                    ))
+
+            try:
+                assert assigned_runner.runner is not None
+                assert assigned_runner.runner.healthy
+
+                async for chunk in assigned_runner.runner.stream_response(
+                        task=op.task,
+                        request_started_callback=partial(running_callback, queue)):
+                    if assigned_runner.shard_metadata.device_rank == 0:
+                        await queue.put(ChunkGenerated(
+                            # TODO: at some point we will no longer have a bijection between task_id and row_id,
+                            # so we probably want to store a mapping between the two on the Worker object.
+                            command_id=chunk.command_id,
+                            chunk=chunk
+                        ))
+
+                if assigned_runner.shard_metadata.device_rank == 0:
+                    await queue.put(TaskStateUpdated(
+                        task_id=op.task.task_id,
+                        task_status=TaskStatus.COMPLETE,
+                    ))
+
+                # After a successful inference:
+                assigned_runner.status = LoadedRunnerStatus()
+                await queue.put(assigned_runner.status_update_event())
+
+            except Exception as e:
+                # An exception occurred in the runner supervisor.
+                self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
+                async for event in self._fail_task(e, op.runner_id, op.task.task_id):
+                    await queue.put(event)
+
+        queue: Queue[Event] = asyncio.Queue()
+        task = asyncio.create_task(inner_execute(queue))
+
+        # TODO: Initial (prefill) timeout can be dynamic
+        # model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
+
+        try:
+            # Yield items from the queue
+            # timeout = 30.
+            timeout = 3.
+            while True:
+                item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
+                yield item
+                timeout = 2.
+                if isinstance(item, RunnerStatusUpdated) and isinstance(
+                    item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
+                ):
+                    if isinstance(item.runner_status, LoadedRunnerStatus):
+                        assigned_runner.failures = []
+
+                    break
+        except TimeoutError as e:
+            # The runner supervisor didn't respond in time, so we put the runner & task into a failed state.
+            self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
+            async for event in self._fail_task(e, op.runner_id, op.task.task_id):
+                yield event
+        finally:
+            # Ensure the task is cleaned up
+            try:
+                await asyncio.wait_for(task, timeout=5)
+            except asyncio.TimeoutError:
+                self.logger.warning("Timed out waiting for task cleanup after inference execution.")
+
+
+    ## Operation Planner
+
+    async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
+        ## It would be great if we could get rid of this async for ... yield pattern.
+        match op.op_type:
+            case RunnerOpType.ASSIGN_RUNNER:
+                event_generator = self._execute_assign_op(op)
+            case RunnerOpType.UNASSIGN_RUNNER:
+                event_generator = self._execute_unassign_op(op)
+            case RunnerOpType.RUNNER_UP:
+                event_generator = self._execute_runner_up_op(op)
+            case RunnerOpType.RUNNER_DOWN:
+                event_generator = self._execute_runner_down_op(op)
+            case RunnerOpType.RUNNER_FAILED:
+                event_generator = self._execute_runner_failed_op(op)
+            case RunnerOpType.CHAT_COMPLETION:
+                event_generator = self._execute_task_op(op)
+
+        async for event in event_generator:
+            yield event
+
+
+    async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event, None]:
+        if runner_id in self.assigned_runners:
+            assigned_runner = self.assigned_runners[runner_id]
+
+            assigned_runner.runner = None
+            assigned_runner.status = FailedRunnerStatus(error_message=str(e))
+            assigned_runner.failures.append(
+                (
+                    time.time(),
+                    e
+                )
+            )
+
+            # The failure count is reset to 0 on a successful inference (see
+            # _execute_task_op); three strikes here means we give up on the instance.
+            if len(assigned_runner.failures) >= 3:
+                # Too many retries. We will emit a DeleteInstance
+                yield InstanceDeleted(
+                    instance_id=assigned_runner.instance_id
+                )
+
+            yield assigned_runner.status_update_event()
+
+
+    async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event, None]:
+        if runner_id in self.assigned_runners:
+            yield TaskStateUpdated(
+                task_id=task_id,
+                task_status=TaskStatus.FAILED,
+            )
+
+            yield TaskFailed(
+                task_id=task_id,
+                error_type=str(type(e)),
+                error_message=str(e)
+            )
+
+            async for event in self._fail_runner(e, runner_id):
+                yield event
+
+
+    async def event_publisher(self, event: Event) -> None:
+        assert self.worker_events is not None
+        await self.worker_events.append_events([event], self.node_id)
+        self.logger.info(f"published event: {event}")
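
A minimal sketch of how Worker is meant to be driven, assuming a plan(state) entry point like the one the deleted tests above exercise (this patch moves planning into worker/plan.py, so the real wiring may differ). Illustrative only, not part of the diff:

    from shared.types.state import State
    from worker.worker import Worker

    async def worker_tick(worker: Worker, state: State) -> None:
        # plan() maps the desired State to at most one RunnerOp (or None).
        op = worker.plan(state)
        if op is None:
            return
        # execute_op() yields Events (status updates, chunks, failures);
        # publishing them is how the rest of the cluster observes this node.
        async for event in worker.execute_op(op):
            await worker.event_publisher(event)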