fix libp2p + other PRs that were wrongly overwritten before (111, 112, 117, 118, 1119 + misc commits from Alex)

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
Co-authored-by: Seth Howes <71157822+sethhowes@users.noreply.github.com>
Co-authored-by: Matt Beton <matthew.beton@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Gelu Vrabie
2025-07-31 20:36:47 +01:00
committed by GitHub
parent 2031d9481d
commit 0e32599e71
60 changed files with 4048 additions and 2857 deletions

View File

@@ -6,5 +6,5 @@ runs:
using: "composite"
steps:
- name: Format code
run: nix develop -c just fmt
run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just fmt
shell: bash

View File

@@ -6,5 +6,5 @@ runs:
using: "composite"
steps:
- name: Lint check
run: nix develop -c just lint-check
run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint-check
shell: bash

View File

@@ -6,5 +6,5 @@ runs:
using: "composite"
steps:
- name: Lint code
run: nix develop -c just lint
run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint
shell: bash

View File

@@ -6,5 +6,5 @@ runs:
using: "composite"
steps:
- name: Regenerate protobufs
run: nix develop -c just regenerate-protobufs
run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just regenerate-protobufs
shell: bash

View File

@@ -1,12 +1,12 @@
name: Type Check
description: "Run static type checker"
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix develop -c just sync
nix develop -c just check
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

.github/actions/unit-test/action.yml (vendored, new file, 12 lines)
View File

@@ -0,0 +1,12 @@
name: Unit Test
description: "Run unit tests"
runs:
using: "composite"
steps:
- name: Run unit tests
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync-clean
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just test-fast
shell: bash

.github/workflows/e2e_test.yml (vendored, new file, 360 lines)
View File

@@ -0,0 +1,360 @@
name: macOS System Info
on:
workflow_dispatch: # This allows manual triggering
# push:
# branches: [ '*' ]
# tags: [ '*' ]
jobs:
master:
runs-on: ['self-hosted', 'macOS']
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: true
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Reset databases
run: |
if [ -d ~/.exo ]; then
rm -rf ~/.exo/*.db*
fi
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-e2e-master-XXXXXXXX)
# Generate random port (macOS compatible method)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
echo "API_PORT=$API_PORT" >> $GITHUB_ENV
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
echo "Verifying API_PORT is set: $API_PORT"
shell: bash
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Print macOS system information
run: |
echo "=== macOS System Information ==="
echo "OS Version:"
sw_vers
echo -e "\n=== Memory Information ==="
system_profiler SPMemoryDataType
echo -e "\n=== Memory Usage Summary ==="
vm_stat | perl -ne '/page size of (\d+)/ and $size=$1; /Pages free: (\d+)/ and printf "Free Memory: %.2f GB\n", $1 * $size / 1024 / 1024 / 1024'
top -l 1 -s 0 | grep PhysMem
echo -e "\n=== CPU Information ==="
sysctl -n machdep.cpu.brand_string
system_profiler SPHardwareDataType | grep -E "Cores|Processors"
echo -e "\n=== Disk Space ==="
df -h /
# - name: Setup Hugging Face token
# run: |
# mkdir -p ~/.cache/huggingface
# echo "${{ secrets.HF_TOKEN }}" > ~/.cache/huggingface/token
- name: Sync dependencies
run: |
echo "Running just sync-clean to ensure clean dependencies..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync-clean
shell: bash
- name: Build forwarder
run: |
echo "Building Go forwarder binary..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder
shell: bash
- name: Start node (master)
run: |
echo "Starting master node with debug enabled..."
echo "Environment check - API_PORT: '$API_PORT'"
echo "Environment check - EXO_HOME: '$EXO_HOME'"
if [ -z "$API_PORT" ]; then
echo "ERROR: API_PORT is not set!"
exit 1
fi
# Run with Python unbuffered output and maximum debug level
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME API_PORT=$API_PORT PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run master/main.py" > /tmp/master_node.log 2>&1 &
MASTER_PID=$!
echo "Started master node in background with PID: $MASTER_PID"
echo "Log file: /tmp/master_node.log"
echo "Starting worker node..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run worker/main.py" > /tmp/worker_node.log 2>&1 &
WORKER_PID=$!
echo "Started worker node in background with PID: $WORKER_PID"
echo "Log file: /tmp/worker_node.log"
for i in {1..30}; do
echo "Attempt $i: Checking if master node is ready..."
if curl -s http://localhost:$API_PORT/state > /dev/null 2>&1; then
echo "Master node is ready!"
break
fi
if [ $i -eq 30 ]; then
echo "Master node failed to start within 30 seconds. Checking logs..."
echo "=== Master node log ==="
cat /tmp/master_node.log || echo "No master log file found"
echo "=== Worker node log ==="
cat /tmp/worker_node.log || echo "No worker log file found"
exit 1
fi
sleep 1
done
# wait for master to have a COMPLETE or FAILED task in the state
for i in {1..30}; do
if [ "$(curl -s http://localhost:$API_PORT/state | jq -r '.tasks | any(.task_status == "COMPLETE" or .task_status == "FAILED")')" = "true" ]; then
echo "Master node has a COMPLETE or FAILED task in the state"
break
fi
sleep 1
done
echo "=== Master node log ==="
cat /tmp/master_node.log || echo "No master log file found"
echo "=== Worker node log ==="
cat /tmp/worker_node.log || echo "No worker log file found"
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()
worker:
runs-on: ['self-hosted', 'macOS']
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: true
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Reset databases
run: |
if [ -d ~/.exo ]; then
rm -rf ~/.exo/*.db*
fi
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-e2e-worker-XXXXXXXX)
# Generate random port (macOS compatible method)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
echo "API_PORT=$API_PORT" >> $GITHUB_ENV
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
echo "Verifying API_PORT is set: $API_PORT"
shell: bash
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Print macOS system information
run: |
echo "=== macOS System Information ==="
echo "OS Version:"
sw_vers
echo -e "\n=== Memory Information ==="
system_profiler SPMemoryDataType
echo -e "\n=== Memory Usage Summary ==="
vm_stat | perl -ne '/page size of (\d+)/ and $size=$1; /Pages free: (\d+)/ and printf "Free Memory: %.2f GB\n", $1 * $size / 1024 / 1024 / 1024'
top -l 1 -s 0 | grep PhysMem
echo -e "\n=== CPU Information ==="
sysctl -n machdep.cpu.brand_string
system_profiler SPHardwareDataType | grep -E "Cores|Processors"
echo -e "\n=== Disk Space ==="
df -h /
# - name: Setup Hugging Face token
# run: |
# mkdir -p ~/.cache/huggingface
# echo "${{ secrets.HF_TOKEN }}" > ~/.cache/huggingface/token
- name: Sync dependencies
run: |
echo "Running just sync-clean to ensure clean dependencies..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync-clean
shell: bash
- name: Build forwarder
run: |
echo "Building Go forwarder binary..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder
shell: bash
- name: Start node (replica)
run: |
echo "Starting master node with debug enabled..."
echo "Environment check - API_PORT: '$API_PORT'"
echo "Environment check - EXO_HOME: '$EXO_HOME'"
if [ -z "$API_PORT" ]; then
echo "ERROR: API_PORT is not set!"
exit 1
fi
# Run with Python unbuffered output and maximum debug level
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_RUN_AS_REPLICA=1 EXO_HOME=$EXO_HOME API_PORT=$API_PORT PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run master/main.py" > /tmp/master_node.log 2>&1 &
MASTER_PID=$!
echo "Started master node in background with PID: $MASTER_PID"
echo "Log file: /tmp/master_node.log"
echo "Starting worker node..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c "EXO_HOME=$EXO_HOME PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run worker/main.py" > /tmp/worker_node.log 2>&1 &
WORKER_PID=$!
echo "Started worker node in background with PID: $WORKER_PID"
echo "Log file: /tmp/worker_node.log"
echo "Waiting for master node to start on port $API_PORT..."
# Wait for the master node to be ready (up to 30 seconds)
for i in {1..30}; do
echo "Attempt $i: Checking if master node is ready..."
if curl -s http://localhost:$API_PORT/state > /dev/null 2>&1; then
echo "Master node is ready!"
break
fi
if [ $i -eq 30 ]; then
echo "Master node failed to start within 30 seconds. Checking logs..."
echo "=== Master node log ==="
cat /tmp/master_node.log || echo "No master log file found"
echo "=== Worker node log ==="
cat /tmp/worker_node.log || echo "No worker log file found"
exit 1
fi
sleep 1
done
resp=$(curl -X POST http://localhost:$API_PORT/instance -H "Content-Type: application/json" -d '{"model_id": "llama-3.2:1b"}')
echo "Response: $resp"
instance_id=$(echo $resp | jq -r '.instance_id')
echo "Instance ID: $instance_id"
for i in {1..50}; do
resp=$(curl -s -w "%{http_code}" -X GET http://localhost:$API_PORT/instance/$instance_id -H "Content-Type: application/json")
http_code="${resp: -3}"
response_body="${resp%???}"
echo "HTTP Code: $http_code"
echo "Response: $response_body"
if [ "$http_code" == "200" ]; then
instance_status=$(echo $response_body | jq -r '.instance_type')
if [ "$instance_status" == "ACTIVE" ]; then
echo "Instance is ready"
break
fi
elif [ "$http_code" == "404" ]; then
echo "Instance not yet created, waiting..."
else
echo "Unexpected HTTP status: $http_code"
fi
sleep 1
done
resp=$(curl http://localhost:$API_PORT/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "llama-3.2:1b", "messages": [{"role": "user", "content": "What is the meaning of exo?"}], "temperature": 0.7}')
echo "Response: $resp"
resp=$(curl -X DELETE http://localhost:$API_PORT/instance/$instance_id -H "Content-Type: application/json")
echo "Response: $resp"
echo "=== Master node log ==="
cat /tmp/master_node.log || echo "No master log file found"
echo "=== Worker node log ==="
cat /tmp/worker_node.log || echo "No worker log file found"
kill $MASTER_PID
kill $WORKER_PID
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()

View File

@@ -12,9 +12,12 @@ on:
jobs:
typecheck:
runs-on: ubuntu-22.04
runs-on: ['self-hosted', 'macOS']
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: true
- name: Configure git user
run: |
@@ -22,23 +25,54 @@ jobs:
git config --local user.name "github-actions bot"
shell: bash
- uses: cachix/install-nix-action@v31
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- uses: ./.github/actions/typecheck
ci:
needs: typecheck
runs-on: ubuntu-22.04
runs-on: ['self-hosted', 'macOS']
permissions:
contents: read
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
lfs: true
- name: Configure git user
run: |
@@ -46,12 +80,67 @@ jobs:
git config --local user.name "github-actions bot"
shell: bash
- uses: cachix/install-nix-action@v31
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# Generate random port (macOS compatible method)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
echo "API_PORT=$API_PORT" >> $GITHUB_ENV
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
shell: bash
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Build forwarder
run: |
echo "Building Go forwarder binary..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just build-forwarder
shell: bash
- uses: ./.github/actions/verify-clean
with:
step: regenerate-protobufs
- uses: ./.github/actions/lint-check
- uses: ./.github/actions/unit-test
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()

View File

@@ -1,6 +1,7 @@
import asyncio
import concurrent.futures
import os
import resource
from asyncio import AbstractEventLoop
from typing import Any, Callable
@@ -18,6 +19,8 @@ from shared.types.worker.shards import ShardMetadata
from worker.download.download_utils import build_model_path
from worker.runner.communication import runner_print
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
def mx_barrier():
mx.eval( # type: ignore
@@ -86,6 +89,7 @@ def shard_and_load(model_shard_meta: ShardMetadata) -> tuple[nn.Module, Tokenize
tokenizer = load_tokenizer(model_path)
assert isinstance(tokenizer, TokenizerWrapper)
model = auto_parallel(model, model_shard_meta)
mx.eval(model.parameters()) # type: ignore
# Synchronize processes before generation to avoid timeout
mx_barrier()
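The hard-coded resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) call above can fail with ValueError on hosts whose existing hard limit is below 4096, since an unprivileged process may not raise its own hard limit. A minimal defensive sketch (not part of this commit) that only raises the soft limit and clamps it to the hard limit already in place:

import resource

def raise_nofile_limit(target: int = 2048) -> int:
    # Only raise the soft limit; never request more than the current hard limit.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft != resource.RLIM_INFINITY and soft < target:
        new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
        resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
        soft = new_soft
    return soft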

View File

@@ -19,6 +19,9 @@ lint-check:
test:
uv run pytest master worker shared engines/*
test-fast:
uv run pytest master shared engines/*
check:
uv run basedpyright --project pyproject.toml

View File

@@ -49,6 +49,9 @@ class DiscoverySupervisor:
send_back_multiaddr = Multiaddr(address=str(e.send_back_addr))
connection_profile = None
if send_back_multiaddr.ipv4_address == local_multiaddr.ipv4_address:
return
topology_edge_created = TopologyEdgeCreated(edge=Connection(
local_node_id=local_node_id,
send_back_node_id=send_back_node_id,
@@ -56,7 +59,7 @@ class DiscoverySupervisor:
send_back_multiaddr=send_back_multiaddr,
connection_profile=connection_profile
))
self.logger.error(
self.logger.info(
msg=f"CONNECTED CALLBACK: {local_node_id} -> {send_back_node_id}, {local_multiaddr} -> {send_back_multiaddr}")
await self.global_events.append_events(
[topology_edge_created],

View File

@@ -111,6 +111,7 @@ class ForwarderSupervisor:
env_vars["FORWARDER_NODE_ID"] = str(self.node_id)
self._process = await asyncio.create_subprocess_exec(
str(self._binary_path),
"--events-db", str(EXO_WORKER_EVENT_DB),
f'{pairs}',
stdout=None,
stderr=None,

View File

@@ -9,7 +9,8 @@ from typing import List
from exo_pyo3_bindings import Keypair
from master.api import start_fastapi_server
from master.discovery_supervisor import DiscoverySupervisor
# from master.discovery_supervisor import DiscoverySupervisor
from master.election_callback import ElectionCallbacks
from master.forwarder_supervisor import ForwarderRole, ForwarderSupervisor
from master.placement import get_instance_placements, get_transition_events
@@ -45,13 +46,13 @@ class Master:
self.command_buffer = command_buffer
self.global_events = global_events
self.worker_events = worker_events
self.discovery_supervisor = DiscoverySupervisor(
node_id_keypair,
node_id,
# TODO: needs to be more general for when we have master election
worker_events if os.getenv('EXO_RUN_AS_REPLICA') in set(['TRUE', 'true', '1']) else global_events,
logger
)
# self.discovery_supervisor = DiscoverySupervisor(
# node_id_keypair,
# node_id,
# # TODO: needs to be more general for when we have master election
# worker_events if os.getenv('EXO_RUN_AS_REPLICA') in set(['TRUE', 'true', '1']) else global_events,
# logger
# )
self.forwarder_supervisor = ForwarderSupervisor(
self.node_id,
forwarder_binary_path=forwarder_binary_path,
@@ -116,7 +117,7 @@ class Master:
await self.event_log_for_writes.append_events(next_events, origin=self.node_id)
# 2. get latest events
events = await self.event_log_for_reads.get_events_since(self.state.last_event_applied_idx)
events = await self.event_log_for_reads.get_events_since(self.state.last_event_applied_idx, ignore_no_op_events=True)
if len(events) == 0:
await asyncio.sleep(0.01)
return
@@ -157,7 +158,7 @@ class Master:
async def main():
logger = logging.getLogger('master_logger')
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
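The ignore_no_op_events=True argument added to get_events_since above implies the read path can drop filler events before the master applies them. A rough, self-contained illustration of that contract; every name below is hypothetical and not taken from this repo's API:

from dataclasses import dataclass, field

@dataclass
class Event:
    idx: int
    event_type: str  # e.g. "NO_OP" for heartbeat/keepalive entries (assumed label)

@dataclass
class InMemoryEventLog:
    _events: list[Event] = field(default_factory=list)

    async def append(self, event: Event) -> None:
        self._events.append(event)

    async def get_events_since(self, last_idx: int, ignore_no_op_events: bool = False) -> list[Event]:
        # Return everything after last_idx, optionally skipping no-op entries
        # so a tight polling loop does not re-process them.
        events = [e for e in self._events if e.idx > last_idx]
        if ignore_no_op_events:
            events = [e for e in events if e.event_type != "NO_OP"]
        return events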

View File

@@ -114,8 +114,7 @@ def test_remove_connection_still_connected(topology: Topology, node_profile: Nod
topology.remove_connection(connection)
# assert
with pytest.raises(IndexError):
topology.get_connection_profile(connection)
assert topology.get_connection_profile(connection) is None
def test_remove_connection_bridge(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
@@ -130,6 +129,8 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma
topology.add_node(Node(node_id=node_a_id, node_profile=node_profile))
topology.add_node(Node(node_id=node_b_id, node_profile=node_profile))
topology.set_master_node_id(master_id)
connection_master_to_a = Connection(
local_node_id=master_id,
send_back_node_id=node_a_id,
@@ -157,11 +158,8 @@ def test_remove_connection_bridge(topology: Topology, node_profile: NodePerforma
assert len(remaining_nodes) == 1
assert remaining_nodes[0].node_id == master_id
with pytest.raises(KeyError):
topology.get_node_profile(node_a_id)
with pytest.raises(KeyError):
topology.get_node_profile(node_b_id)
assert topology.get_node_profile(node_a_id) is None
assert topology.get_node_profile(node_b_id) is None
def test_remove_node_still_connected(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
@@ -174,8 +172,7 @@ def test_remove_node_still_connected(topology: Topology, node_profile: NodePerfo
topology.remove_node(connection.local_node_id)
# assert
with pytest.raises(KeyError):
topology.get_node_profile(connection.local_node_id)
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, node_profile: NodePerformanceProfile, connection: Connection):
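The rewritten assertions above expect get_connection_profile and get_node_profile to return None for missing entries instead of raising IndexError/KeyError. A minimal dictionary-backed sketch of that accessor contract (illustrative only; the class and field names are assumptions, not the project's Topology implementation):

from typing import Optional

class TopologySketch:
    def __init__(self) -> None:
        self._node_profiles: dict[str, object] = {}
        self._connection_profiles: dict[str, object] = {}

    def get_node_profile(self, node_id: str) -> Optional[object]:
        # None signals "unknown node" instead of raising KeyError.
        return self._node_profiles.get(node_id)

    def get_connection_profile(self, connection_key: str) -> Optional[object]:
        # None signals "no such connection" instead of raising IndexError.
        return self._connection_profiles.get(connection_key)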

View File

@@ -1,16 +1,17 @@
module forwarder
go 1.23
go 1.23.8
toolchain go1.24.3
replace forwarder/src => ./src
require (
github.com/google/uuid v1.6.0
github.com/libp2p/go-libp2p v0.39.1
github.com/google/uuid v1.6.0
github.com/libp2p/go-libp2p v0.42.1
github.com/libp2p/go-libp2p-pubsub v0.14.2
github.com/mattn/go-sqlite3 v1.14.28
github.com/multiformats/go-multiaddr v0.16.0
github.com/stretchr/testify v1.10.0
)
@@ -18,110 +19,99 @@ require (
github.com/benbjohnson/clock v1.3.5 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/cgroups v1.1.0 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/elastic/gosigar v0.14.3 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/francoispqt/gojay v1.2.13 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/google/pprof v0.0.0-20250202011525-fc3143867406 // indirect
github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/huin/goupnp v1.3.0 // indirect
github.com/ipfs/go-cid v0.5.0 // indirect
github.com/ipfs/go-log/v2 v2.5.1 // indirect
github.com/ipfs/go-log/v2 v2.6.0 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/koron/go-ssdp v0.0.5 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/koron/go-ssdp v0.0.6 // indirect
github.com/libp2p/go-buffer-pool v0.1.0 // indirect
github.com/libp2p/go-flow-metrics v0.2.0 // indirect
github.com/libp2p/go-libp2p-asn-util v0.4.1 // indirect
github.com/libp2p/go-msgio v0.3.0 // indirect
github.com/libp2p/go-nat v0.2.0 // indirect
github.com/libp2p/go-netroute v0.2.2 // indirect
github.com/libp2p/go-reuseport v0.4.0 // indirect
github.com/libp2p/go-yamux/v4 v4.0.2 // indirect
github.com/libp2p/go-yamux/v5 v5.0.1 // indirect
github.com/libp2p/zeroconf/v2 v2.2.0 // indirect
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.63 // indirect
github.com/miekg/dns v1.1.66 // indirect
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect
github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/multiformats/go-base32 v0.1.0 // indirect
github.com/multiformats/go-base36 v0.2.0 // indirect
github.com/multiformats/go-multiaddr v0.14.0 // indirect
github.com/multiformats/go-multiaddr-dns v0.4.1 // indirect
github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect
github.com/multiformats/go-multibase v0.2.0 // indirect
github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multicodec v0.9.1 // indirect
github.com/multiformats/go-multihash v0.2.3 // indirect
github.com/multiformats/go-multistream v0.6.0 // indirect
github.com/multiformats/go-multistream v0.6.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.22.2 // indirect
github.com/opencontainers/runtime-spec v1.2.0 // indirect
github.com/onsi/ginkgo/v2 v2.23.4 // indirect
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
github.com/pion/datachannel v1.5.10 // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/dtls/v3 v3.0.4 // indirect
github.com/pion/ice/v2 v2.3.37 // indirect
github.com/pion/ice/v4 v4.0.6 // indirect
github.com/pion/interceptor v0.1.37 // indirect
github.com/pion/dtls/v3 v3.0.6 // indirect
github.com/pion/ice/v4 v4.0.10 // indirect
github.com/pion/interceptor v0.1.40 // indirect
github.com/pion/logging v0.2.3 // indirect
github.com/pion/mdns v0.0.12 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.15 // indirect
github.com/pion/rtp v1.8.11 // indirect
github.com/pion/sctp v1.8.35 // indirect
github.com/pion/sdp/v3 v3.0.10 // indirect
github.com/pion/srtp/v3 v3.0.4 // indirect
github.com/pion/rtp v1.8.19 // indirect
github.com/pion/sctp v1.8.39 // indirect
github.com/pion/sdp/v3 v3.0.13 // indirect
github.com/pion/srtp/v3 v3.0.6 // indirect
github.com/pion/stun v0.6.1 // indirect
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/pion/webrtc/v4 v4.0.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pion/turn/v4 v4.0.2 // indirect
github.com/pion/webrtc/v4 v4.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.20.5 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.64.0 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.49.0 // indirect
github.com/quic-go/quic-go v0.52.0 // indirect
github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 // indirect
github.com/raulk/go-watchdog v1.3.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
go.uber.org/dig v1.18.0 // indirect
go.uber.org/fx v1.23.0 // indirect
go.uber.org/mock v0.5.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/dig v1.19.0 // indirect
go.uber.org/fx v1.24.0 // indirect
go.uber.org/mock v0.5.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/tools v0.29.0 // indirect
google.golang.org/protobuf v1.36.4 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.34.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/blake3 v1.3.0 // indirect
lukechampine.com/blake3 v1.4.1 // indirect
)
// Remember to run `go mod tidy` after adding dependencies.

View File

@@ -9,8 +9,6 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
@@ -20,33 +18,18 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c h1:pFUpOrbxDR6AkioZ1ySsx5yxlDQZ8stG2b88gTPxgJU=
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c/go.mod h1:6UhI8N9EjYm1c2odKpFpAYeR8dsBeM7PtzQhRgxRr9U=
github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5ilcvdfma9wOH6Y=
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U0x++OzVrdms8=
github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
@@ -60,12 +43,7 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@@ -76,18 +54,16 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20250202011525-fc3143867406 h1:wlQI2cYY0BsWmmPPAnxfQ8SDW0S3Jasn+4B8kXFxprg=
github.com/google/pprof v0.0.0-20250202011525-fc3143867406/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a h1://KbezygeMJZCSHH+HgUZiTeSoiuFspbMg1ge+eFj18=
github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
@@ -103,8 +79,8 @@ github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc=
github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/ipfs/go-cid v0.5.0 h1:goEKKhaGm0ul11IHA7I6p1GmKz8kEYniqFopaB5Otwg=
github.com/ipfs/go-cid v0.5.0/go.mod h1:0L7vmeNXpQpUS9vt+yEARkJ8rOg43DF3iPgn4GIN0mk=
github.com/ipfs/go-log/v2 v2.5.1 h1:1XdUzF7048prq4aBjDQQ4SL5RxftpRGdXhNRwKSAlcY=
github.com/ipfs/go-log/v2 v2.5.1/go.mod h1:prSpmC1Gpllc9UYWxDiZDreBYw7zp4Iqp1kOLU9U5UI=
github.com/ipfs/go-log/v2 v2.6.0 h1:2Nu1KKQQ2ayonKp4MPo6pXCjqw1ULc9iohRqWV5EYqg=
github.com/ipfs/go-log/v2 v2.6.0/go.mod h1:p+Efr3qaY5YXpx9TX7MoLCSEZX5boSWj9wh86P5HJa8=
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jbenet/go-temp-err-catcher v0.1.0 h1:zpb3ZH6wIE8Shj2sKS+khgRvf7T7RABoLk/+KKHggpk=
@@ -112,15 +88,14 @@ github.com/jbenet/go-temp-err-catcher v0.1.0/go.mod h1:0kJRvmDZXNMIiJirNPEYfhpPw
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/koron/go-ssdp v0.0.5 h1:E1iSMxIs4WqxTbIBLtmNBeOOC+1sCIXQeqTWVnpmwhk=
github.com/koron/go-ssdp v0.0.5/go.mod h1:Qm59B7hpKpDqfyRNWRNr00jGwLdXjDyZh6y7rH6VS0w=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/koron/go-ssdp v0.0.6 h1:Jb0h04599eq/CY7rB5YEqPS83HmRfHP2azkxMN2rFtU=
github.com/koron/go-ssdp v0.0.6/go.mod h1:0R9LfRJGek1zWTjN3JUNlm5INCDYGpRDfAptnct63fI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -134,8 +109,8 @@ github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6
github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg=
github.com/libp2p/go-flow-metrics v0.2.0 h1:EIZzjmeOE6c8Dav0sNv35vhZxATIXWZg6j/C08XmmDw=
github.com/libp2p/go-flow-metrics v0.2.0/go.mod h1:st3qqfu8+pMfh+9Mzqb2GTiwrAGjIPszEjZmtksN8Jc=
github.com/libp2p/go-libp2p v0.39.1 h1:1Ur6rPCf3GR+g8jkrnaQaM0ha2IGespsnNlCqJLLALE=
github.com/libp2p/go-libp2p v0.39.1/go.mod h1:3zicI8Lp7Isun+Afo/JOACUbbJqqR2owK6RQWFsVAbI=
github.com/libp2p/go-libp2p v0.42.1 h1:Rt8+5thie729NQk1gx1h/2t/+VIafWcqR1I+Kvw+UTg=
github.com/libp2p/go-libp2p v0.42.1/go.mod h1:4NGcjbD9OIvFiSRb0XueCO19zJ4kSPK5vkyyOUYmMro=
github.com/libp2p/go-libp2p-asn-util v0.4.1 h1:xqL7++IKD9TBFMgnLPZR6/6iYhawHKHl950SO9L6n94=
github.com/libp2p/go-libp2p-asn-util v0.4.1/go.mod h1:d/NI6XZ9qxw67b4e+NgpQexCIiFYJjErASrYW4PFDN8=
github.com/libp2p/go-libp2p-pubsub v0.14.2 h1:nT5lFHPQOFJcp9CW8hpKtvbpQNdl2udJuzLQWbgRum8=
@@ -144,21 +119,18 @@ github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUI
github.com/libp2p/go-libp2p-testing v0.12.0/go.mod h1:KcGDRXyN7sQCllucn1cOOS+Dmm7ujhfEyXQL5lvkcPg=
github.com/libp2p/go-msgio v0.3.0 h1:mf3Z8B1xcFN314sWX+2vOTShIE0Mmn2TXn3YCUQGNj0=
github.com/libp2p/go-msgio v0.3.0/go.mod h1:nyRM819GmVaF9LX3l03RMh10QdOroF++NBbxAb0mmDM=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFPuZ8=
github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE=
github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s=
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU=
github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o=
github.com/libp2p/go-yamux/v5 v5.0.1 h1:f0WoX/bEF2E8SbE4c/k1Mo+/9z0O4oC/hWEA+nfYRSg=
github.com/libp2p/go-yamux/v5 v5.0.1/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU=
github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q=
github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk=
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
@@ -166,8 +138,8 @@ github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE=
github.com/miekg/dns v1.1.66/go.mod h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE=
github.com/mikioh/tcp v0.0.0-20190314235350-803a9b46060c h1:bzE/A84HN25pxAuk9Eej1Kz9OUelF97nAc82bDquQI8=
github.com/mikioh/tcp v0.0.0-20190314235350-803a9b46060c/go.mod h1:0SQS9kMwD2VsyFEB++InYyBJroV/FRmBgcydeSUcJms=
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b h1:z78hV3sbSMAUoyUMM0I83AUIT6Hu17AWfgjzIbtrYFc=
@@ -188,34 +160,31 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg
github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0=
github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4=
github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo=
github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU=
github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4=
github.com/multiformats/go-multiaddr v0.16.0 h1:oGWEVKioVQcdIOBlYM8BH1rZDWOGJSqr9/BKl6zQ4qc=
github.com/multiformats/go-multiaddr v0.16.0/go.mod h1:JSVUmXDjsVFiW7RjIFMP7+Ev+h1DTbiJgVeTV/tcmP0=
github.com/multiformats/go-multiaddr-dns v0.4.1 h1:whi/uCLbDS3mSEUMb1MsoT4uzUeZB0N32yzufqS0i5M=
github.com/multiformats/go-multiaddr-dns v0.4.1/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc=
github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E=
github.com/multiformats/go-multiaddr-fmt v0.1.0/go.mod h1:hGtDIW4PU4BqJ50gW2quDuPVjyWNZxToGUh/HwTZYJo=
github.com/multiformats/go-multibase v0.2.0 h1:isdYCVLvksgWlMW9OZRYJEa9pZETFivncJHmHnnd87g=
github.com/multiformats/go-multibase v0.2.0/go.mod h1:bFBZX4lKCA/2lyOFSAoKH5SS6oPyjtnzK/XTFDPkNuk=
github.com/multiformats/go-multicodec v0.9.0 h1:pb/dlPnzee/Sxv/j4PmkDRxCOi3hXTz3IbPKOXWJkmg=
github.com/multiformats/go-multicodec v0.9.0/go.mod h1:L3QTQvMIaVBkXOXXtVmYE+LI16i14xuaojr/H7Ai54k=
github.com/multiformats/go-multicodec v0.9.1 h1:x/Fuxr7ZuR4jJV4Os5g444F7xC4XmyUaT/FWtE+9Zjo=
github.com/multiformats/go-multicodec v0.9.1/go.mod h1:LLWNMtyV5ithSBUo3vFIMaeDy+h3EbkMTek1m+Fybbo=
github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew=
github.com/multiformats/go-multihash v0.2.3 h1:7Lyc8XfX/IY2jWb/gI7JP+o7JEq9hOa7BFvVU9RSh+U=
github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsCKRVch90MdaGiKsvSM=
github.com/multiformats/go-multistream v0.6.0 h1:ZaHKbsL404720283o4c/IHQXiS6gb8qAN5EIJ4PN5EA=
github.com/multiformats/go-multistream v0.6.0/go.mod h1:MOyoG5otO24cHIg8kf9QW2/NozURlkP/rvi2FQJyCPg=
github.com/multiformats/go-multistream v0.6.1 h1:4aoX5v6T+yWmc2raBHsTvzmFhOI8WVOer28DeBBEYdQ=
github.com/multiformats/go-multistream v0.6.1/go.mod h1:ksQf6kqHAb6zIsyw7Zm+gAuVo57Qbq84E27YlYqavqw=
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU=
github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk=
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.2.0 h1:z97+pHb3uELt/yiAWD691HNHQIF07bE7dzrbT927iTk=
github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus=
github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU=
github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
@@ -224,33 +193,29 @@ github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oL
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk=
github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U=
github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg=
github.com/pion/ice/v2 v2.3.37 h1:ObIdaNDu1rCo7hObhs34YSBcO7fjslJMZV0ux+uZWh0=
github.com/pion/ice/v2 v2.3.37/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ=
github.com/pion/ice/v4 v4.0.6 h1:jmM9HwI9lfetQV/39uD0nY4y++XZNPhvzIPCb8EwxUM=
github.com/pion/ice/v4 v4.0.6/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI=
github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y=
github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4=
github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4=
github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8=
github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
github.com/pion/rtp v1.8.11 h1:17xjnY5WO5hgO6SD3/NTIUPvSFw/PbLsIJyz1r1yNIk=
github.com/pion/rtp v1.8.11/go.mod h1:8uMBJj32Pa1wwx8Fuv/AsFhn8jsgw+3rUC2PfoBZ8p4=
github.com/pion/sctp v1.8.35 h1:qwtKvNK1Wc5tHMIYgTDJhfZk7vATGVHhXbUDfHbYwzA=
github.com/pion/sctp v1.8.35/go.mod h1:EcXP8zCYVTRy3W9xtOF7wJm1L1aXfKRQzaM33SjQlzg=
github.com/pion/sdp/v3 v3.0.10 h1:6MChLE/1xYB+CjumMw+gZ9ufp2DPApuVSnDT8t5MIgA=
github.com/pion/sdp/v3 v3.0.10/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M=
github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ=
github.com/pion/rtp v1.8.19 h1:jhdO/3XhL/aKm/wARFVmvTfq0lC/CvN1xwYKmduly3c=
github.com/pion/rtp v1.8.19/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE=
github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
github.com/pion/sdp/v3 v3.0.13 h1:uN3SS2b+QDZnWXgdr69SM8KB4EbcnPnPf2Laxhty/l4=
github.com/pion/sdp/v3 v3.0.13/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4=
github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY=
github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4=
github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
@@ -259,45 +224,38 @@ github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1A
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q=
github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc=
github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
github.com/pion/webrtc/v4 v4.0.8 h1:T1ZmnT9qxIJIt4d8XoiMOBrTClGHDDXNg9e/fh018Qc=
github.com/pion/webrtc/v4 v4.0.8/go.mod h1:HHBeUVBAC+j4ZFnYhovEFStF02Arb1EyD4G7e7HBTJw=
github.com/pion/turn/v4 v4.0.2 h1:ZqgQ3+MjP32ug30xAbD6Mn+/K4Sxi3SdNOTFf+7mpps=
github.com/pion/turn/v4 v4.0.2/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs=
github.com/pion/webrtc/v4 v4.1.2 h1:mpuUo/EJ1zMNKGE79fAdYNFZBX790KE7kQQpLMjjR54=
github.com/pion/webrtc/v4 v4.1.2/go.mod h1:xsCXiNAmMEjIdFxAYU0MbB3RwRieJsegSB2JZsGN+8U=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4=
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.49.0 h1:w5iJHXwHxs1QxyBv1EHKuC50GX5to8mJAxvtnttJp94=
github.com/quic-go/quic-go v0.49.0/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA=
github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ=
github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg=
github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw=
github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk=
github.com/raulk/go-watchdog v1.3.0/go.mod h1:fIvOnLbF0b0ZwkB9YU4mOW9Did//4vPZtDqv66NfsMU=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
@@ -319,10 +277,8 @@ github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b
github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ=
github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk=
github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4=
github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
@@ -331,9 +287,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
@@ -341,7 +294,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
@@ -349,23 +301,20 @@ github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw=
go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg=
go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU=
go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg=
go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=
@@ -382,24 +331,22 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4=
golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -407,7 +354,6 @@ golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -415,7 +361,6 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
@@ -423,8 +368,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -440,38 +385,31 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180810173357-98c5dad5d1a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -488,28 +426,25 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE=
golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -530,26 +465,22 @@ google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmE
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
lukechampine.com/blake3 v1.3.0 h1:sJ3XhFINmHSrYCgl958hscfIa3bw8x4DqMP3u1YvoYE=
lukechampine.com/blake3 v1.3.0/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k=
lukechampine.com/blake3 v1.4.1 h1:I3Smz7gso8w4/TunLKec6K2fn+kyKtDxr/xcQEN84Wg=
lukechampine.com/blake3 v1.4.1/go.mod h1:QFosUxmjB8mnrWFSNwKmvxHpfY72bmD2tQ0kBMM3kwo=
sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=

View File

@@ -11,6 +11,7 @@ import (
)
var nodeID = flag.String("node-id", "", "Node ID (defaults to FORWARDER_NODE_ID env var or a new UUID)")
var eventsDBPath = flag.String("events-db", "", "Path to the worker events SQLite database")
func main() {
flag.Parse()
@@ -23,6 +24,12 @@ func main() {
}
log.Printf("Starting forwarder with node ID: %s", id)
// Set the events database path if provided
if *eventsDBPath != "" {
forwarder.SetEventsDBPath(*eventsDBPath)
log.Printf("Using events database: %s", *eventsDBPath)
}
args := flag.Args()
if len(args) == 0 {
log.Fatal("forwarding pairs argument is required as the first positional argument (of the form {source}|{sink}) where source and sink sqlite:db_file:table_name or libp2p:topic")

View File

@@ -0,0 +1,259 @@
package forwarder
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"strconv"
"sync"
"github.com/google/uuid"
"github.com/libp2p/go-libp2p/core/network"
_ "github.com/mattn/go-sqlite3"
"github.com/multiformats/go-multiaddr"
)
var (
eventsDBPath string
eventsDB *sql.DB
eventsDBMu sync.Mutex
)
// SetEventsDBPath sets the path to the events database
func SetEventsDBPath(path string) {
eventsDBMu.Lock()
defer eventsDBMu.Unlock()
eventsDBPath = path
}
// Event types matching Python's _EventType enum
const (
EventTypeTopologyEdgeCreated = "TopologyEdgeCreated"
EventTypeTopologyEdgeDeleted = "TopologyEdgeDeleted"
)
// ConnectionProfile matches Python's ConnectionProfile (optional)
type ConnectionProfile struct {
Throughput float64 `json:"throughput"`
Latency float64 `json:"latency"`
Jitter float64 `json:"jitter"`
}
// Multiaddr matches Python's Multiaddr structure
type Multiaddr struct {
Address string `json:"address"`
IPv4Address string `json:"ipv4_address,omitempty"`
Port int `json:"port,omitempty"`
}
// Connection matches Python's Connection model
type Connection struct {
LocalNodeID string `json:"local_node_id"`
SendBackNodeID string `json:"send_back_node_id"`
LocalMultiaddr Multiaddr `json:"local_multiaddr"`
SendBackMultiaddr Multiaddr `json:"send_back_multiaddr"`
ConnectionProfile *ConnectionProfile `json:"connection_profile"`
}
// TopologyEdgeCreated matches Python's TopologyEdgeCreated event
type TopologyEdgeCreated struct {
EventType string `json:"event_type"`
EventID string `json:"event_id"`
Edge Connection `json:"edge"`
}
// TopologyEdgeDeleted matches Python's TopologyEdgeDeleted event
type TopologyEdgeDeleted struct {
EventType string `json:"event_type"`
EventID string `json:"event_id"`
Edge Connection `json:"edge"`
}
// initEventsDB initializes the events database connection
func initEventsDB() error {
eventsDBMu.Lock()
defer eventsDBMu.Unlock()
if eventsDB != nil {
return nil // Already initialized
}
if eventsDBPath == "" {
return nil // No events DB configured
}
var err error
eventsDB, err = sql.Open("sqlite3", eventsDBPath)
if err != nil {
return fmt.Errorf("failed to open events database: %w", err)
}
// Create table if it doesn't exist (matching Python's schema)
createTableSQL := `
CREATE TABLE IF NOT EXISTS events (
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
origin TEXT NOT NULL,
event_type TEXT NOT NULL,
event_id TEXT NOT NULL,
event_data TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_events_origin ON events(origin);
CREATE INDEX IF NOT EXISTS idx_events_event_type ON events(event_type);
CREATE INDEX IF NOT EXISTS idx_events_created_at ON events(created_at);
`
_, err = eventsDB.Exec(createTableSQL)
if err != nil {
eventsDB.Close()
eventsDB = nil
return fmt.Errorf("failed to create events table: %w", err)
}
return nil
}
// writeEvent writes an event to the database
func writeEvent(eventType string, eventData interface{}) error {
if eventsDB == nil {
if err := initEventsDB(); err != nil {
return err
}
if eventsDB == nil {
return nil // No events DB configured
}
}
// Serialize event data to JSON
jsonData, err := json.Marshal(eventData)
if err != nil {
return fmt.Errorf("failed to marshal event data: %w", err)
}
// Extract event ID from the event data
var eventID string
switch e := eventData.(type) {
case *TopologyEdgeCreated:
eventID = e.EventID
case *TopologyEdgeDeleted:
eventID = e.EventID
default:
eventID = uuid.New().String()
}
// Insert event into database
insertSQL := `INSERT INTO events (origin, event_type, event_id, event_data) VALUES (?, ?, ?, ?)`
_, err = eventsDB.Exec(insertSQL, GetNodeId(), eventType, eventID, string(jsonData))
if err != nil {
return fmt.Errorf("failed to insert event: %w", err)
}
return nil
}
// NotifeeHandler implements the libp2p network.Notifiee interface
type NotifeeHandler struct{}
// Listen is called when network starts listening on an addr
func (n *NotifeeHandler) Listen(net network.Network, ma multiaddr.Multiaddr) {}
// ListenClose is called when network stops listening on an addr
func (n *NotifeeHandler) ListenClose(net network.Network, ma multiaddr.Multiaddr) {}
// Connected is called when a connection is opened
func (n *NotifeeHandler) Connected(net network.Network, conn network.Conn) {
remotePeer := conn.RemotePeer()
localAddr := conn.LocalMultiaddr()
remoteAddr := conn.RemoteMultiaddr()
// Get the actual node IDs (not peer IDs)
localNodeID := GetNodeId()
// For remote node, we need to extract from peer ID or use a mapping
// For now, we'll use the peer ID as a placeholder
// TODO: Implement proper node ID mapping/discovery
remoteNodeID := remotePeer.String()
// Create connection event
event := &TopologyEdgeCreated{
EventType: EventTypeTopologyEdgeCreated,
EventID: uuid.New().String(),
Edge: Connection{
LocalNodeID: localNodeID,
SendBackNodeID: remoteNodeID,
LocalMultiaddr: parseMultiaddr(localAddr),
SendBackMultiaddr: parseMultiaddr(remoteAddr),
ConnectionProfile: nil, // TODO: Add connection profiling if needed
},
}
// Write event to database
if err := writeEvent(EventTypeTopologyEdgeCreated, event); err != nil {
log.Printf("Failed to write edge created event: %v", err)
} else {
log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID)
}
}
// Disconnected is called when a connection is closed
func (n *NotifeeHandler) Disconnected(net network.Network, conn network.Conn) {
remotePeer := conn.RemotePeer()
localAddr := conn.LocalMultiaddr()
remoteAddr := conn.RemoteMultiaddr()
// Get the actual node IDs (not peer IDs)
localNodeID := GetNodeId()
remoteNodeID := remotePeer.String() // TODO: Implement proper node ID mapping
// Create disconnection event
event := &TopologyEdgeDeleted{
EventType: EventTypeTopologyEdgeDeleted,
EventID: uuid.New().String(),
Edge: Connection{
LocalNodeID: localNodeID,
SendBackNodeID: remoteNodeID,
LocalMultiaddr: parseMultiaddr(localAddr),
SendBackMultiaddr: parseMultiaddr(remoteAddr),
ConnectionProfile: nil,
},
}
// Write event to database
if err := writeEvent(EventTypeTopologyEdgeDeleted, event); err != nil {
log.Printf("Failed to write edge deleted event: %v", err)
} else {
log.Printf("Wrote edge deleted event: %s -> %s", localNodeID, remoteNodeID)
}
}
// OpenedStream is called when a stream is opened
func (n *NotifeeHandler) OpenedStream(net network.Network, str network.Stream) {}
// ClosedStream is called when a stream is closed
func (n *NotifeeHandler) ClosedStream(net network.Network, str network.Stream) {}
// parseMultiaddr converts a libp2p multiaddr to our Multiaddr struct
func parseMultiaddr(ma multiaddr.Multiaddr) Multiaddr {
result := Multiaddr{
Address: ma.String(),
}
// Extract IPv4 address if present
if ipStr, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil {
result.IPv4Address = ipStr
}
// Extract port if present
if portStr, err := ma.ValueForProtocol(multiaddr.P_TCP); err == nil {
if port, err := strconv.Atoi(portStr); err == nil {
result.Port = port
}
}
return result
}
// GetNotifee returns a network.Notifiee that records topology change events
func GetNotifee() network.Notifiee {
return &NotifeeHandler{}
}
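For reference, the events this notifee writes can be read straight back out of the SQLite table created by initEventsDB. A minimal standalone sketch; only the schema above is assumed, and the database path is hypothetical:

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", "worker_events.db") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Fetch the ten most recent events, newest first.
	rows, err := db.Query(`SELECT origin, event_type, event_id, event_data
	                       FROM events ORDER BY rowid DESC LIMIT 10`)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var origin, eventType, eventID, eventData string
		if err := rows.Scan(&origin, &eventType, &eventID, &eventData); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s  %s  %s\n%s\n", origin, eventType, eventID, eventData)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}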

View File

@@ -6,6 +6,10 @@ import (
"crypto/sha256"
"encoding/json"
"log"
"net"
"os"
"sort"
"strings"
"sync"
"time"
@@ -15,9 +19,11 @@ import (
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/pnet"
mdns "github.com/libp2p/go-libp2p/p2p/discovery/mdns"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/multiformats/go-multiaddr"
)
var node host.Host
@@ -28,22 +34,337 @@ var mu sync.Mutex
var refCount int
var topicsMap = make(map[string]*pubsub.Topic)
// Connection retry state tracking
type peerConnState struct {
retryCount int
lastAttempt time.Time
}
var peerLastAddrs = make(map[peer.ID][]multiaddr.Multiaddr)
var addrsMu sync.Mutex
var connecting = make(map[peer.ID]bool)
var connMu sync.Mutex
var peerRetryState = make(map[peer.ID]*peerConnState)
var retryMu sync.Mutex
const (
maxRetries = 5 // Increased for more tolerance to transient failures
initialBackoff = 2 * time.Second
maxBackoff = 33 * time.Second
retryResetTime = 1 * time.Minute // Reduced for faster recovery after max retries
)
type discoveryNotifee struct {
h host.Host
}
// sortAddrs returns a sorted copy of addresses for comparison
func sortAddrs(addrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
s := make([]multiaddr.Multiaddr, len(addrs))
copy(s, addrs)
sort.Slice(s, func(i, j int) bool {
return s[i].String() < s[j].String()
})
return s
}
// addrsChanged checks if two address sets differ
func addrsChanged(a, b []multiaddr.Multiaddr) bool {
if len(a) != len(b) {
return true
}
sa := sortAddrs(a)
sb := sortAddrs(b)
for i := range sa {
if !sa[i].Equal(sb[i]) {
return true
}
}
return false
}
// isAddressValid checks if an address should be used for connections
func isAddressValid(addr multiaddr.Multiaddr) bool {
// Allow loopback for testing if env var is set
allowLoopback := os.Getenv("FORWARDER_ALLOW_LOOPBACK") == "true"
// Check IPv4 addresses
ipStr, err := addr.ValueForProtocol(multiaddr.P_IP4)
if err == nil && ipStr != "" {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// Filter out loopback, unspecified addresses (unless testing)
if !allowLoopback && (ip.IsLoopback() || ip.IsUnspecified()) {
return false
}
if ip.IsUnspecified() {
return false
}
// Filter out common VPN ranges (Tailscale uses 100.64.0.0/10)
if ip.To4() != nil && ip.To4()[0] == 100 && ip.To4()[1] >= 64 && ip.To4()[1] <= 127 {
return false
}
}
// Check IPv6 addresses
ipStr, err = addr.ValueForProtocol(multiaddr.P_IP6)
if err == nil && ipStr != "" {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// Filter out loopback, unspecified addresses (unless testing)
if !allowLoopback && (ip.IsLoopback() || ip.IsUnspecified()) {
return false
}
if ip.IsUnspecified() {
return false
}
// Filter out Tailscale IPv6 (fd7a:115c:a1e0::/48)
if strings.HasPrefix(strings.ToLower(ipStr), "fd7a:115c:a1e0:") {
return false
}
}
return true
}
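A rough sketch of how the filter above behaves on a few sample addresses; the addresses are invented for the example, it assumes this sits alongside the same forwarder package, and the loopback result flips if FORWARDER_ALLOW_LOOPBACK=true is set:

package forwarder

import (
	"testing"

	"github.com/multiformats/go-multiaddr"
)

// TestIsAddressValidExamples sketches the expected filtering behaviour; it
// assumes FORWARDER_ALLOW_LOOPBACK is unset when run.
func TestIsAddressValidExamples(t *testing.T) {
	cases := map[string]bool{
		"/ip4/192.168.1.42/tcp/4001":      true,  // ordinary LAN address
		"/ip4/127.0.0.1/tcp/4001":         false, // loopback
		"/ip4/100.100.7.2/tcp/4001":       false, // Tailscale CGNAT range 100.64.0.0/10
		"/ip6/fd7a:115c:a1e0::1/tcp/4001": false, // Tailscale IPv6 prefix
	}
	for s, want := range cases {
		ma, err := multiaddr.NewMultiaddr(s)
		if err != nil {
			t.Fatalf("bad multiaddr %q: %v", s, err)
		}
		if got := isAddressValid(ma); got != want {
			t.Errorf("isAddressValid(%q) = %v, want %v", s, got, want)
		}
	}
}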
// customInterfaceAddresses returns IPs only from interfaces that are up and running (i.e., have link)
func customInterfaceAddresses() ([]net.IP, error) {
var ips []net.IP
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, ifi := range ifaces {
if ifi.Flags&net.FlagUp == 0 || ifi.Flags&net.FlagRunning == 0 {
continue
}
addrs, err := ifi.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP != nil {
ips = append(ips, ipnet.IP)
}
}
}
return ips, nil
}
// customAddrsFactory expands wildcard listen addrs to actual IPs on up+running interfaces, then filters
func customAddrsFactory(listenAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
ips, err := customInterfaceAddresses()
if err != nil {
log.Printf("Error getting interface IPs: %v", err)
return nil
}
var advAddrs []multiaddr.Multiaddr
for _, la := range listenAddrs {
comps := multiaddr.Split(la)
if len(comps) == 0 {
continue
}
first := comps[0]
protos := first.Protocols()
if len(protos) == 0 {
continue
}
code := protos[0].Code
val, err := first.ValueForProtocol(code)
var isWildcard bool
if err == nil && ((code == multiaddr.P_IP4 && val == "0.0.0.0") || (code == multiaddr.P_IP6 && val == "::")) {
isWildcard = true
}
if isWildcard {
// Expand to each valid IP
for _, ip := range ips {
var pcodeStr string
if ip.To4() != nil {
pcodeStr = "4"
} else {
pcodeStr = "6"
}
newIPStr := "/ip" + pcodeStr + "/" + ip.String()
newIPMA, err := multiaddr.NewMultiaddr(newIPStr)
if err != nil {
continue
}
var newComps []multiaddr.Multiaddrer
newComps = append(newComps, newIPMA)
for _, c := range comps[1:] {
newComps = append(newComps, c.Multiaddr())
}
newa := multiaddr.Join(newComps...)
if isAddressValid(newa) {
advAddrs = append(advAddrs, newa)
}
}
} else if isAddressValid(la) {
advAddrs = append(advAddrs, la)
}
}
return advAddrs
}
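A quick way to see what this factory would advertise on a given machine is to feed it a wildcard listen address; the test name and port below are arbitrary, and the output depends entirely on the local interfaces:

package forwarder

import (
	"testing"

	"github.com/multiformats/go-multiaddr"
)

// TestCustomAddrsFactoryWildcard prints what this machine would advertise
// for a wildcard IPv4 listen address.
func TestCustomAddrsFactoryWildcard(t *testing.T) {
	wildcard, err := multiaddr.NewMultiaddr("/ip4/0.0.0.0/tcp/4001")
	if err != nil {
		t.Fatal(err)
	}
	for _, a := range customAddrsFactory([]multiaddr.Multiaddr{wildcard}) {
		t.Logf("would advertise %s", a)
	}
}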
func (n *discoveryNotifee) HandlePeerFound(pi peer.AddrInfo) {
if n.h.ID() >= pi.ID {
return
}
log.Printf("mDNS discovered peer %s with %d addresses", pi.ID, len(pi.Addrs))
// Check if already connected first
if n.h.Network().Connectedness(pi.ID) == network.Connected {
log.Printf("Already connected to peer %s", pi.ID)
return
}
ctx := context.Background()
// Clear any existing addresses for this peer to ensure we use only fresh ones from mDNS
ps := n.h.Peerstore()
ps.ClearAddrs(pi.ID)
log.Printf("Cleared old addresses for peer %s", pi.ID)
// During normal operation, only the lower-ID side dials to avoid double connections
// But if we have retry state for this peer, both sides should attempt
// Also, if we have no connections at all, both sides should attempt
retryMu.Lock()
_, hasRetryState := peerRetryState[pi.ID]
retryMu.Unlock()
// Check if we should skip based on ID comparison
// Skip only if we have a higher ID, no retry state, and we already have connections
if n.h.ID() >= pi.ID && !hasRetryState && len(n.h.Network().Peers()) > 0 {
log.Printf("Skipping initial connection to peer %s (lower ID)", pi.ID)
return
}
// Filter addresses before attempting connection
var filteredAddrs []multiaddr.Multiaddr
for _, addr := range pi.Addrs {
if isAddressValid(addr) {
filteredAddrs = append(filteredAddrs, addr)
log.Printf("Valid address for %s: %s", pi.ID, addr)
} else {
log.Printf("Filtered out address for %s: %s", pi.ID, addr)
}
}
if len(filteredAddrs) == 0 {
log.Printf("No valid addresses for peer %s after filtering, skipping connection attempt", pi.ID)
return
}
// Check for address changes and reset retries if changed
addrsMu.Lock()
lastAddrs := peerLastAddrs[pi.ID]
addrsMu.Unlock()
if addrsChanged(lastAddrs, filteredAddrs) {
log.Printf("Detected address change for peer %s, resetting retry count", pi.ID)
retryMu.Lock()
if state, ok := peerRetryState[pi.ID]; ok {
state.retryCount = 0
}
retryMu.Unlock()
// Update last known addresses
addrsMu.Lock()
peerLastAddrs[pi.ID] = append([]multiaddr.Multiaddr(nil), filteredAddrs...) // Copy
addrsMu.Unlock()
}
pi.Addrs = filteredAddrs
// Add the filtered addresses to the peerstore with a reasonable TTL
ps.AddAddrs(pi.ID, filteredAddrs, peerstore.TempAddrTTL)
// Attempt connection with retry logic
go n.connectWithRetry(pi)
}
func (n *discoveryNotifee) connectWithRetry(pi peer.AddrInfo) {
// Serialize connection attempts per peer
connMu.Lock()
if connecting[pi.ID] {
connMu.Unlock()
log.Printf("Already connecting to peer %s, skipping duplicate attempt", pi.ID)
return
}
connecting[pi.ID] = true
connMu.Unlock()
defer func() {
connMu.Lock()
delete(connecting, pi.ID)
connMu.Unlock()
}()
retryMu.Lock()
state, exists := peerRetryState[pi.ID]
if !exists {
state = &peerConnState{}
peerRetryState[pi.ID] = state
}
// Check if we've exceeded max retries
if state.retryCount >= maxRetries {
// Check if enough time has passed to reset retry count
if time.Since(state.lastAttempt) > retryResetTime {
state.retryCount = 0
log.Printf("Reset retry count for peer %s due to time elapsed", pi.ID)
} else {
retryMu.Unlock()
log.Printf("Max retries reached for peer %s, skipping", pi.ID)
return
}
}
// Calculate backoff duration
backoffDuration := time.Duration(1<<uint(state.retryCount)) * initialBackoff
if backoffDuration > maxBackoff {
backoffDuration = maxBackoff
}
// Check if we need to wait before retrying
if state.retryCount > 0 && time.Since(state.lastAttempt) < backoffDuration {
retryMu.Unlock()
log.Printf("Backoff active for peer %s, skipping attempt", pi.ID)
return
}
state.lastAttempt = time.Now()
retryMu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := n.h.Connect(ctx, pi); err != nil {
log.Printf("Failed to connect to %s: %v", pi.ID.String(), err)
log.Printf("Failed to connect to %s (attempt %d/%d): %v", pi.ID, state.retryCount+1, maxRetries, err)
retryMu.Lock()
state.retryCount++
retryMu.Unlock()
// Schedule retry if we haven't exceeded max attempts
if state.retryCount < maxRetries {
time.AfterFunc(backoffDuration, func() {
// Check if we're still not connected before retrying
if n.h.Network().Connectedness(pi.ID) != network.Connected {
n.connectWithRetry(pi)
}
})
}
} else {
log.Printf("Connected to %s", pi.ID.String())
log.Printf("Successfully connected to %s", pi.ID)
// Reset retry state on successful connection
retryMu.Lock()
delete(peerRetryState, pi.ID)
retryMu.Unlock()
addrsMu.Lock()
delete(peerLastAddrs, pi.ID)
addrsMu.Unlock()
log.Printf("Cleared last addresses for disconnected peer %s", pi.ID)
}
}
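With the constants above, the backoff doubles per failed attempt from initialBackoff and is capped at maxBackoff, i.e. roughly 2s, 4s, 8s, 16s, 32s over the five allowed retries. A standalone sketch of the same computation:

package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		maxRetries     = 5
		initialBackoff = 2 * time.Second
		maxBackoff     = 33 * time.Second
	)
	for retry := 0; retry < maxRetries; retry++ {
		// Same formula as connectWithRetry: exponential backoff with a cap.
		wait := time.Duration(1<<uint(retry)) * initialBackoff
		if wait > maxBackoff {
			wait = maxBackoff
		}
		fmt.Printf("after failed attempt %d wait %v\n", retry+1, wait)
	}
}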
@@ -76,6 +397,9 @@ func getNode(ctx context.Context) {
opts = append(opts, libp2p.EnableHolePunching()) // Better NAT traversal
opts = append(opts, libp2p.EnableRelay()) // Allow relaying
// Custom address factory to avoid advertising down interfaces
opts = append(opts, libp2p.AddrsFactory(customAddrsFactory))
node, err = libp2p.New(opts...)
if err != nil {
log.Fatalf("failed to create host: %v", err)
@@ -103,9 +427,118 @@ func getNode(ctx context.Context) {
node.Close()
log.Fatalf("failed to start mdns service: %v", err)
}
// Register disconnect notifiee to clear stale addresses
node.Network().Notify(&disconnectNotifee{})
// Register event notifiee to track topology changes
node.Network().Notify(GetNotifee())
// Start a goroutine to periodically trigger mDNS discovery
go periodicMDNSDiscovery()
})
}
// periodicMDNSDiscovery ensures mDNS continues to work after network changes
func periodicMDNSDiscovery() {
// Start with faster checks, then slow down
fastCheckDuration := 5 * time.Second
slowCheckDuration := 30 * time.Second
currentDuration := fastCheckDuration
noConnectionCount := 0
ticker := time.NewTicker(currentDuration)
defer ticker.Stop()
for range ticker.C {
if mdnsSer == nil || node == nil {
return
}
// Log current connection status
peers := node.Network().Peers()
if len(peers) == 0 {
noConnectionCount++
log.Printf("No connected peers (check #%d), mDNS service running: %v", noConnectionCount, mdnsSer != nil)
// Force mDNS to re-announce when we have no peers
// This helps recovery after network interface changes
if noConnectionCount > 1 { // Skip first check to avoid unnecessary restart
forceRestartMDNS()
}
// Keep fast checking when disconnected
if currentDuration != fastCheckDuration {
currentDuration = fastCheckDuration
ticker.Reset(currentDuration)
log.Printf("Switching to fast mDNS checks (every %v)", currentDuration)
}
} else {
log.Printf("Currently connected to %d peers", len(peers))
noConnectionCount = 0
// Switch to slow checking when connected
if currentDuration != slowCheckDuration {
currentDuration = slowCheckDuration
ticker.Reset(currentDuration)
log.Printf("Switching to slow mDNS checks (every %v)", currentDuration)
}
}
}
}
// forceRestartMDNS restarts the mDNS service to force re-announcement
func forceRestartMDNS() {
mu.Lock()
defer mu.Unlock()
if mdnsSer != nil && node != nil {
log.Printf("Force restarting mDNS service for re-announcement")
oldMdns := mdnsSer
rendezvous := "forwarder_network"
notifee := &discoveryNotifee{h: node}
newMdns := mdns.NewMdnsService(node, rendezvous, notifee)
if err := newMdns.Start(); err != nil {
log.Printf("Failed to restart mDNS service: %v", err)
} else {
oldMdns.Close()
mdnsSer = newMdns
log.Printf("Successfully restarted mDNS service")
}
}
}
// disconnectNotifee clears stale peer addresses on disconnect
type disconnectNotifee struct{}
func (d *disconnectNotifee) Connected(network.Network, network.Conn) {}
func (d *disconnectNotifee) Disconnected(n network.Network, c network.Conn) {
p := c.RemotePeer()
ps := n.Peerstore()
// Clear all addresses from peerstore to force fresh discovery on reconnect
ps.ClearAddrs(p)
// Also clear retry state for this peer
retryMu.Lock()
delete(peerRetryState, p)
retryMu.Unlock()
log.Printf("Cleared stale addresses and retry state for disconnected peer %s", p)
// Try to restart mDNS discovery after a short delay to handle network interface changes
go func() {
time.Sleep(2 * time.Second)
log.Printf("Triggering mDNS re-discovery after disconnect")
forceRestartMDNS()
}()
}
func (d *disconnectNotifee) OpenedStream(network.Network, network.Stream) {}
func (d *disconnectNotifee) ClosedStream(network.Network, network.Stream) {}
func (d *disconnectNotifee) Listen(network.Network, multiaddr.Multiaddr) {}
func (d *disconnectNotifee) ListenClose(network.Network, multiaddr.Multiaddr) {}
type libP2PConnector struct {
topic string
sub *pubsub.Subscription

View File

@@ -114,3 +114,7 @@ extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
[tool.pytest.ini_options]
pythonpath = "."
asyncio_mode = "auto"
markers = [
"slow: marks tests as slow (deselected by default)"
]
addopts = "-m 'not slow'"

run.sh
View File

@@ -40,7 +40,7 @@ fi
# Second command (master) - changes based on replica flag
if [ "$REPLICA" = true ]; then
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export EXO_RUN_AS_REPLICA=1 EXO_HOME=.exo_replica API_PORT=8001; uv run -m master.main'\""
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true EXO_RUN_AS_REPLICA=1 EXO_HOME=.exo_replica API_PORT=8001; uv run -m master.main'\""
else
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c uv run -m master.main\""
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true; uv run -m master.main'\""
fi

View File

@@ -200,7 +200,7 @@ fn mdns_behaviour(keypair: &identity::Keypair) -> AnyResult<mdns::tokio::Behavio
// mDNS config => enable IPv6
let mdns_config = Config {
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
enable_ipv6: true,
..Default::default()
};

View File

@@ -17,6 +17,7 @@
use crate::behaviour::{discovery_behaviour, DiscoveryBehaviour};
use crate::transport::discovery_transport;
use libp2p::{identity, Swarm, SwarmBuilder};
use std::net::IpAddr;
pub mod behaviour;
pub mod transport;
@@ -49,11 +50,18 @@ pub fn discovery_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm<Dis
.with_behaviour(discovery_behaviour)?
.build();
// Listen on all interfaces and whatever port the OS assigns
// swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?; // TODO: make this
let listen_addr = "/ip4/0.0.0.0/tcp/0".parse()?;
log::info!("RUST: Attempting to listen on: {}", listen_addr);
swarm.listen_on(listen_addr)?;
// Listen on IPv4
let listen_addr_ipv4 = "/ip4/0.0.0.0/tcp/0".parse()?;
log::info!("RUST: Attempting to listen on: {}", listen_addr_ipv4);
swarm.listen_on(listen_addr_ipv4)?;
// Listen on IPv6 - try but don't fail if not available
let listen_addr_ipv6 = "/ip6/::/tcp/0".parse()?;
log::info!("RUST: Attempting to listen on: {}", listen_addr_ipv6);
match swarm.listen_on(listen_addr_ipv6) {
Ok(_) => log::info!("RUST: Successfully listening on IPv6"),
Err(e) => log::warn!("RUST: Failed to listen on IPv6 (this is okay if IPv6 is not available): {:?}", e),
}
Ok(swarm)
}

View File

@@ -33,7 +33,8 @@ fn tcp_transport(
};
// `TCP_NODELAY` enabled => avoid latency
let tcp_config = Config::default().nodelay(true);
let tcp_config = Config::default()
.nodelay(true);
// V1 + lazy flushing => 0-RTT negotiation
let upgrade_version = Version::V1Lazy;

View File

@@ -18,12 +18,14 @@ use libp2p::multiaddr::multiaddr;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{ConnectionId, SwarmEvent, ToSwarm};
use libp2p::{Multiaddr, PeerId, Swarm, gossipsub, mdns};
use std::net::IpAddr;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{Bound, Py, PyObject, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::convert::identity;
use std::error::Error;
use tokio::sync::mpsc;
use tokio::time::{interval, Duration};
struct ConnectionUpdate {
/// Identity of the peer that we have connected to.
@@ -77,6 +79,46 @@ enum IncomingDiscoveryMessage {
AddDisconnectedCallback(Box<dyn alias::SendFn<(ConnectionUpdate,), ()>>),
}
/// Check if a multiaddr is valid for connection
fn is_address_valid(addr: &Multiaddr) -> bool {
use libp2p::multiaddr::Protocol;
for component in addr.iter() {
match component {
Protocol::Ip4(ip) => {
let ip_addr = IpAddr::V4(ip);
// Filter out loopback and unspecified addresses
if ip_addr.is_loopback() || ip_addr.is_unspecified() {
return false;
}
// Filter out Tailscale ranges (100.64.0.0/10)
if let IpAddr::V4(ipv4) = ip_addr {
let octets = ipv4.octets();
if octets[0] == 100 && octets[1] >= 64 && octets[1] <= 127 {
return false;
}
}
}
Protocol::Ip6(ip) => {
let ip_addr = IpAddr::V6(ip);
// Filter out loopback and unspecified addresses
if ip_addr.is_loopback() || ip_addr.is_unspecified() {
return false;
}
// Filter out Tailscale IPv6 (fd7a:115c:a1e0::/48)
if let IpAddr::V6(ipv6) = ip_addr {
let segments = ipv6.segments();
if segments[0] == 0xfd7a && segments[1] == 0x115c && segments[2] == 0xa1e0 {
return false;
}
}
}
_ => {}
}
}
true
}
#[allow(clippy::enum_glob_use)]
async fn discovery_task(
mut receiver: mpsc::Receiver<IncomingDiscoveryMessage>,
@@ -94,8 +136,59 @@ async fn discovery_task(
let mut connected_callbacks: Vec<Box<dyn alias::SendFn<(ConnectionUpdate,), ()>>> = vec![];
let mut disconnected_callbacks: Vec<Box<dyn alias::SendFn<(ConnectionUpdate,), ()>>> = vec![];
// Create periodic health check timer with adaptive interval
let fast_check_duration = Duration::from_secs(5);
let slow_check_duration = Duration::from_secs(30);
let mut health_check_interval = interval(fast_check_duration);
let mut no_connection_count = 0;
loop {
tokio::select! {
_ = health_check_interval.tick() => {
// Check connection health periodically
let connected_peers = swarm.connected_peers().count();
if connected_peers == 0 {
no_connection_count += 1;
log::info!("RUST: No connected peers (check #{no_connection_count})");
// Keep fast checking when disconnected
if health_check_interval.period() != fast_check_duration {
health_check_interval = interval(fast_check_duration);
log::info!("RUST: Switching to fast health checks (every {:?})", fast_check_duration);
}
// Force mDNS restart after multiple failed checks
if no_connection_count > 1 { // Trigger faster, after 2 checks
log::info!("RUST: Attempting to restart mDNS discovery");
// Note: In rust-libp2p, we can't easily restart mDNS like in Go,
// but we can force a re-announce by changing listening addresses
// This is a workaround to trigger mDNS to re-announce
// Try listening on a new ephemeral port to force re-announcement
match swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse().unwrap()) {
Ok(_) => log::info!("RUST: Added new listener to force mDNS re-announcement"),
Err(e) => log::error!("RUST: Failed to add new listener: {e:?}"),
}
// Also try IPv6
match swarm.listen_on("/ip6/::/tcp/0".parse().unwrap()) {
Ok(_) => log::info!("RUST: Added IPv6 listener to force mDNS re-announcement"),
Err(e) => log::error!("RUST: Failed to add IPv6 listener: {e:?}"),
}
}
} else {
if no_connection_count > 0 {
log::info!("RUST: Connection restored, currently connected to {connected_peers} peers");
}
no_connection_count = 0;
// Switch to slow checking when connected
if health_check_interval.period() != slow_check_duration {
health_check_interval = interval(slow_check_duration);
log::info!("RUST: Switching to slow health checks (every {:?})", slow_check_duration);
}
}
}
message = receiver.recv() => {
// handle closed channel
let Some(message) = message else {
@@ -120,6 +213,13 @@ async fn discovery_task(
Behaviour(Mdns(Discovered(list))) => {
for (peer_id, multiaddr) in list {
log::info!("RUST: mDNS discovered a new peer: {peer_id} on {multiaddr}");
// Filter out invalid addresses
if !is_address_valid(&multiaddr) {
log::info!("RUST: Filtered out invalid address: {multiaddr}");
continue;
}
let local_peer_id = *swarm.local_peer_id();
// To avoid simultaneous dial races, only dial peers whose peer_id is lexicographically larger than ours.
if peer_id > local_peer_id {
@@ -234,12 +334,36 @@ async fn discovery_task(
send_back_addr: send_back_addr.clone(),
});
}
// If this was the last connection to the peer, try to force mDNS re-discovery
if num_established == 0 {
log::info!("RUST: Last connection to peer {peer_id} closed, triggering mDNS re-discovery");
// Remove from gossipsub to ensure clean state
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
// Note: the task below only logs after a delay; re-announcement itself is driven by the periodic health check adding a new listener
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await;
log::info!("RUST: Delayed mDNS trigger after disconnect");
});
}
}
NewListenAddr { address, .. } => {
log::info!("RUST: Local node is listening on {address}");
let local_peer = swarm.local_peer_id();
log::info!("RUST: Local peer_id: {local_peer}");
}
OutgoingConnectionError { peer_id, error, .. } => {
log::error!("RUST: Outgoing connection error to peer {peer_id:?}: {error:?}");
// Connection failed, might be due to network change
if let Some(peer) = peer_id {
// Remove from gossipsub to allow fresh connection attempts
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer);
}
}
IncomingConnectionError { send_back_addr, error, .. } => {
log::error!("RUST: Incoming connection error from {send_back_addr}: {error:?}");
}
e => {
log::debug!("RUST: Other event {e:?}");
}

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
import copy
from functools import singledispatch
from typing import Mapping, TypeVar
from typing import Mapping
# from shared.topology import Topology
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
EventFromEventLog,
Heartbeat,
InstanceActivated,
InstanceCreated,
InstanceDeactivated,
@@ -35,20 +34,25 @@ from shared.types.worker.common import NodeStatus, RunnerId
from shared.types.worker.instances import Instance, InstanceId, InstanceStatus
from shared.types.worker.runners import RunnerStatus
S = TypeVar("S", bound=State)
@singledispatch
def event_apply(event: Event, state: State) -> State:
"""Apply an event to *state*.
Events decorated with ``@no_op_event`` set ``__no_apply__ = True`` on the
class. Such events are considered *no-ops* and therefore leave the state
unchanged without requiring a dedicated handler in this dispatch table.
"""
if getattr(event, "__no_apply__", False):
return state
raise RuntimeError(f"no handler registered for event type {type(event).__name__}")
def apply(state: State, event: EventFromEventLog[Event]) -> State:
new_state: State = event_apply(event.event, state)
return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})
@event_apply.register(Heartbeat)
def apply_heartbeat(event: Heartbeat, state: State) -> State:
return state
@event_apply.register(TaskCreated)
def apply_task_created(event: TaskCreated, state: State) -> State:
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
@@ -148,10 +152,6 @@ def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> Sta
new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state}
return state.model_copy(update={"node_status": new_node_status})
@event_apply.register(ChunkGenerated)
def apply_chunk_generated(event: ChunkGenerated, state: State) -> State:
return state
@event_apply.register(TopologyNodeCreated)
def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
@@ -164,6 +164,13 @@ def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> Sta
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
opposite_edge = Connection(
local_node_id=event.edge.send_back_node_id,
send_back_node_id=event.edge.local_node_id,
local_multiaddr=event.edge.send_back_multiaddr,
send_back_multiaddr=event.edge.local_multiaddr
)
topology.add_connection(opposite_edge)
return state.model_copy(update={"topology": topology})
@event_apply.register(TopologyEdgeReplacedAtomically)
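The dispatch contract above is small enough to demonstrate end to end. A self-contained sketch of the same no-op pattern, using toy stand-in types rather than the real Event/State models (not part of this commit):
from dataclasses import dataclass, replace
from functools import singledispatch
@dataclass(frozen=True)
class ToyState:
    counter: int = 0
class ToyEvent:
    __no_apply__ = False
class ToyHeartbeat(ToyEvent):
    __no_apply__ = True  # mirrors what @no_op_event does to the class
class ToyIncrement(ToyEvent):
    pass
@singledispatch
def toy_apply(event: ToyEvent, state: ToyState) -> ToyState:
    # Same fallback contract as event_apply: no-ops pass through untouched,
    # anything else must have a registered handler.
    if getattr(event, "__no_apply__", False):
        return state
    raise RuntimeError(f"no handler registered for {type(event).__name__}")
@toy_apply.register(ToyIncrement)
def _(event: ToyIncrement, state: ToyState) -> ToyState:
    return replace(state, counter=state.counter + 1)
state = ToyState()
state = toy_apply(ToyHeartbeat(), state)   # no handler needed, state unchanged
state = toy_apply(ToyIncrement(), state)   # handled, counter becomes 1
assert state.counter == 1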

View File

@@ -1,6 +1,7 @@
import asyncio
import contextlib
import json
import random
from asyncio import Queue, Task
from collections.abc import Sequence
from logging import Logger, getLogger
@@ -8,8 +9,8 @@ from pathlib import Path
from typing import Any, cast
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlmodel import SQLModel
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, create_async_engine
from shared.types.events import Event, EventParser, NodeId
from shared.types.events._events import Heartbeat
@@ -81,7 +82,8 @@ class AsyncSQLiteEventStorage:
async def get_events_since(
self,
last_idx: int
last_idx: int,
ignore_no_op_events: bool = False
) -> Sequence[EventFromEventLog[Event]]:
"""Retrieve events after a specific index."""
if self._closed:
@@ -107,8 +109,11 @@ class AsyncSQLiteEventStorage:
event_data: dict[str, Any] = cast(dict[str, Any], json.loads(raw_event_data))
else:
event_data = cast(dict[str, Any], raw_event_data)
event = EventParser.validate_python(event_data)
if ignore_no_op_events and event.__no_apply__:
continue
events.append(EventFromEventLog(
event=EventParser.validate_python(event_data),
event=event,
origin=NodeId(origin),
idx_in_log=rowid # rowid becomes idx_in_log
))
@@ -169,17 +174,65 @@ class AsyncSQLiteEventStorage:
echo=False,
connect_args={
"check_same_thread": False,
}
"timeout": 30.0, # Connection timeout in seconds
},
pool_pre_ping=True, # Test connections before using them
pool_size=5,
max_overflow=10
)
# Create tables using SQLModel
# Create tables with proper race condition handling
async with self._engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# First check if the table exists using SQLite's master table
result = await conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' AND name='events'")
)
table_exists = result.fetchone() is not None
# Enable WAL mode and other optimizations
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=10000"))
if not table_exists:
try:
# Use CREATE TABLE IF NOT EXISTS as a more atomic operation
# This avoids race conditions between check and create
await conn.execute(text("""
CREATE TABLE IF NOT EXISTS events (
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
origin TEXT NOT NULL,
event_type TEXT NOT NULL,
event_id TEXT NOT NULL,
event_data TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"""))
# Create indexes if they don't exist
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_origin ON events(origin)"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_event_type ON events(event_type)"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_event_id ON events(event_id)"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_created_at ON events(created_at)"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_events_origin_created ON events(origin, created_at)"))
self._logger.info("Events table and indexes created successfully")
except OperationalError as e:
# Even with IF NOT EXISTS, log any unexpected errors
self._logger.error(f"Error creating table: {e}")
# Re-check if table exists now
result = await conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' AND name='events'")
)
if result.fetchone() is None:
raise RuntimeError(f"Failed to create events table: {e}") from e
else:
self._logger.info("Events table exists (likely created by another process)")
else:
self._logger.debug("Events table already exists")
# Enable WAL mode and other optimizations with retry logic
await self._execute_pragma_with_retry(conn, [
"PRAGMA journal_mode=WAL",
"PRAGMA synchronous=NORMAL",
"PRAGMA cache_size=10000",
"PRAGMA busy_timeout=30000" # 30 seconds busy timeout
])
async def _batch_writer(self) -> None:
"""Background task that drains the queue and commits batches.
@@ -250,6 +303,69 @@ class AsyncSQLiteEventStorage:
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
self._logger.debug(f"Committed batch of {len(batch)} events")
except OperationalError as e:
if "database is locked" in str(e):
self._logger.warning(f"Database locked during batch commit, will retry: {e}")
# Retry with exponential backoff
await self._commit_batch_with_retry(batch)
else:
self._logger.error(f"Failed to commit batch: {e}")
raise
except Exception as e:
self._logger.error(f"Failed to commit batch: {e}")
raise
async def _execute_pragma_with_retry(self, conn: AsyncConnection, pragmas: list[str], max_retries: int = 5) -> None:
"""Execute PRAGMA statements with retry logic for database lock errors."""
for pragma in pragmas:
retry_count = 0
base_delay: float = 0.1 # 100ms
while retry_count < max_retries:
try:
await conn.execute(text(pragma))
break
except OperationalError as e:
if "database is locked" in str(e) and retry_count < max_retries - 1:
delay = cast(float, base_delay * (2 ** retry_count) + random.uniform(0, 0.1))
self._logger.warning(f"Database locked on '{pragma}', retry {retry_count + 1}/{max_retries} after {delay:.2f}s")
await asyncio.sleep(delay)
retry_count += 1
else:
self._logger.error(f"Failed to execute '{pragma}' after {retry_count + 1} attempts: {e}")
raise
async def _commit_batch_with_retry(self, batch: list[tuple[Event, NodeId]], max_retries: int = 5) -> None:
"""Commit a batch with retry logic for database lock errors."""
retry_count = 0
base_delay: float = 0.1 # 100ms
while retry_count < max_retries:
try:
assert self._engine is not None
async with AsyncSession(self._engine) as session:
for event, origin in batch:
stored_event = StoredEvent(
origin=origin,
event_type=event.event_type,
event_id=str(event.event_id),
event_data=event.model_dump(mode='json')
)
session.add(stored_event)
await session.commit()
if len([ev for ev in batch if not isinstance(ev[0], Heartbeat)]) > 0:
self._logger.debug(f"Committed batch of {len(batch)} events after {retry_count} retries")
return
except OperationalError as e:
if "database is locked" in str(e) and retry_count < max_retries - 1:
delay = cast(float, base_delay * (2 ** retry_count) + random.uniform(0, 0.1))
self._logger.warning(f"Database locked on batch commit, retry {retry_count + 1}/{max_retries} after {delay:.2f}s")
await asyncio.sleep(delay)
retry_count += 1
else:
self._logger.error(f"Failed to commit batch after {retry_count + 1} attempts: {e}")
raise
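The two retry helpers above share the same backoff schedule. A stripped-down sketch of that schedule as a generic helper; retry_on_locked is hypothetical and not part of this commit:
import asyncio
import random
from typing import Awaitable, Callable, TypeVar
T = TypeVar("T")
async def retry_on_locked(operation: Callable[[], Awaitable[T]],
                          max_retries: int = 5,
                          base_delay: float = 0.1) -> T:
    """Retry an async callable while SQLite reports 'database is locked'."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception as exc:  # the real code narrows this to OperationalError
            if "database is locked" in str(exc) and attempt < max_retries - 1:
                # 0.1s, 0.2s, 0.4s, 0.8s, ... plus up to 100ms of jitter
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.1)
                await asyncio.sleep(delay)
            else:
                raise
    raise RuntimeError("unreachable: the loop either returns or raises")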

View File

@@ -1,5 +1,8 @@
import asyncio
from logging import Logger
from typing import Dict
from typing import Dict, Optional, cast
from sqlalchemy.exc import OperationalError
from shared.constants import EXO_HOME
from shared.db.sqlite.config import EventLogConfig, EventLogType
@@ -25,11 +28,34 @@ class EventLogManager:
EXO_HOME.mkdir(parents=True, exist_ok=True)
# TODO: This seems like it's a pattern to avoid an async __init__ function. But as we know, there's a better pattern for this - using a create() function, like in runner_supervisor.
async def initialize(self) -> None:
"""Initialize both connectors - call this during startup"""
async def initialize(self, max_retries: int = 3) -> None:
"""Initialize both connectors with retry logic - call this during startup"""
# Both master and worker need both connectors
await self.get_connector(EventLogType.WORKER_EVENTS)
await self.get_connector(EventLogType.GLOBAL_EVENTS)
for log_type in [EventLogType.WORKER_EVENTS, EventLogType.GLOBAL_EVENTS]:
retry_count: int = 0
last_error: Optional[Exception] = None
while retry_count < max_retries:
try:
await self.get_connector(log_type)
break
except OperationalError as e:
last_error = e
if "database is locked" in str(e) and retry_count < max_retries - 1:
retry_count += 1
delay = cast(float, 0.5 * (2 ** retry_count))
self._logger.warning(f"Database locked while initializing {log_type.value}, retry {retry_count}/{max_retries} after {delay}s")
await asyncio.sleep(delay)
else:
self._logger.error(f"Failed to initialize {log_type.value} after {retry_count + 1} attempts: {e}")
raise RuntimeError(f"Could not initialize {log_type.value} database after {retry_count + 1} attempts") from e
except Exception as e:
self._logger.error(f"Unexpected error initializing {log_type.value}: {e}")
raise
if retry_count >= max_retries and last_error:
raise RuntimeError(f"Could not initialize {log_type.value} database after {max_retries} attempts") from last_error
self._logger.info("Initialized all event log connectors")
async def get_connector(self, log_type: EventLogType) -> AsyncSQLiteEventStorage:
@@ -37,20 +63,24 @@ class EventLogManager:
if log_type not in self._connectors:
db_path = self._config.get_db_path(log_type)
connector = AsyncSQLiteEventStorage(
db_path=db_path,
batch_size=self._config.batch_size,
batch_timeout_ms=self._config.batch_timeout_ms,
debounce_ms=self._config.debounce_ms,
max_age_ms=self._config.max_age_ms,
logger=self._logger
)
try:
connector = AsyncSQLiteEventStorage(
db_path=db_path,
batch_size=self._config.batch_size,
batch_timeout_ms=self._config.batch_timeout_ms,
debounce_ms=self._config.debounce_ms,
max_age_ms=self._config.max_age_ms,
logger=self._logger
)
# Start the connector (creates tables if needed)
await connector.start()
# Start the connector (creates tables if needed)
await connector.start()
self._connectors[log_type] = connector
self._logger.info(f"Initialized {log_type.value} connector at {db_path}")
self._connectors[log_type] = connector
self._logger.info(f"Initialized {log_type.value} connector at {db_path}")
except Exception as e:
self._logger.error(f"Failed to create {log_type.value} connector: {e}")
raise
return self._connectors[log_type]

View File

@@ -86,8 +86,11 @@ class Topology(TopologyProto):
yield connection
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
try:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def get_node_multiaddr(self, node_id: NodeId) -> Multiaddr:
for connection in self.list_connections():
@@ -106,8 +109,11 @@ class Topology(TopologyProto):
self._graph.update_edge_by_index(rx_idx, connection)
def get_connection_profile(self, connection: Connection) -> ConnectionProfile | None:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
@@ -118,28 +124,23 @@ class Topology(TopologyProto):
def remove_connection(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
print(f"removing connection: {connection}, is bridge: {self._is_bridge(connection)}")
if self._is_bridge(connection):
# Determine the reference node from which reachability is calculated.
# Prefer a master node if the topology knows one; otherwise fall back to
# the local end of the connection being removed.
reference_node_id: NodeId = self.master_node_id if self.master_node_id is not None else connection.local_node_id
orphan_node_ids = self._get_orphan_node_ids(reference_node_id, connection)
print(f"orphan node ids: {orphan_node_ids}")
for orphan_node_id in orphan_node_ids:
orphan_node_rx_id = self._node_id_to_rx_id_map[orphan_node_id]
print(f"removing orphan node: {orphan_node_id}, rx_id: {orphan_node_rx_id}")
self._graph.remove_node(orphan_node_rx_id)
del self._node_id_to_rx_id_map[orphan_node_id]
del self._rx_id_to_node_id_map[orphan_node_rx_id]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
if rx_idx in self._rx_id_to_node_id_map:
del self._rx_id_to_node_id_map[rx_idx]
print(f"topology after edge removal: {self.to_snapshot()}")
def get_cycles(self) -> list[list[Node]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[Node]] = []
@@ -161,14 +162,12 @@ class Topology(TopologyProto):
return topology
def _is_bridge(self, connection: Connection) -> bool:
edge_idx = self._edge_id_to_rx_id_map[connection]
graph_copy: rx.PyDiGraph[Node, Connection] = self._graph.copy()
components_before = rx.strongly_connected_components(graph_copy)
"""Check if removing this connection will orphan any nodes from the master."""
if self.master_node_id is None:
return False
graph_copy.remove_edge_from_index(edge_idx)
components_after = rx.strongly_connected_components(graph_copy)
return components_after > components_before
orphan_node_ids = self._get_orphan_node_ids(self.master_node_id, connection)
return len(orphan_node_ids) > 0
def _get_orphan_node_ids(self, master_node_id: NodeId, connection: Connection) -> list[NodeId]:
"""Return node_ids that become unreachable from `master_node_id` once `connection` is removed.

View File

@@ -3,7 +3,9 @@ from enum import Enum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
TypeVar,
Union,
get_args,
get_origin,
@@ -90,6 +92,7 @@ class _BaseEvent[T: _EventType](BaseModel):
event_type: T
event_id: EventId = EventId()
__no_apply__: bool = False
def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
"""Check if the event was sent by the correct node.
@@ -99,6 +102,20 @@ class _BaseEvent[T: _EventType](BaseModel):
"""
return True
_E = TypeVar("_E", bound=_BaseEvent[Any])
def no_op_event(cls: type[_E]) -> type[_E]:
"""Decorator to mark an event class as a *no-op*.
Events marked as no-ops do not require an `event_apply` registration the
apply layer will simply return the current state unchanged. This reduces
boilerplate and keeps console output quieter for high-frequency events
such as *Heartbeat* or streaming *ChunkGenerated* messages.
"""
cls.__no_apply__ = True # Used by the apply layer to identify no-op events
return cls
@no_op_event
class Heartbeat(_BaseEvent[_EventType.Heartbeat]):
event_type: Literal[_EventType.Heartbeat] = _EventType.Heartbeat
node_id: NodeId
@@ -152,6 +169,7 @@ class InstanceReplacedAtomically(_BaseEvent[_EventType.InstanceReplacedAtomicall
instance_to_replace: InstanceId
new_instance_id: InstanceId
# TODO: RunnerCreated
class RunnerStatusUpdated(_BaseEvent[_EventType.RunnerStatusUpdated]):
event_type: Literal[_EventType.RunnerStatusUpdated] = _EventType.RunnerStatusUpdated
@@ -176,6 +194,7 @@ class WorkerStatusUpdated(_BaseEvent[_EventType.WorkerStatusUpdated]):
node_state: NodeStatus
@no_op_event
class ChunkGenerated(_BaseEvent[_EventType.ChunkGenerated]):
event_type: Literal[_EventType.ChunkGenerated] = _EventType.ChunkGenerated
command_id: CommandId

View File

@@ -14,4 +14,3 @@ class RunnerId(ID):
class NodeStatus(str, Enum):
Idle = "Idle"
Running = "Running"
Paused = "Paused"

View File

@@ -16,7 +16,6 @@ class RunnerOpType(str, Enum):
RUNNER_UP = "runner_up"
RUNNER_DOWN = "runner_down"
RUNNER_FAILED = "runner_failed"
DOWNLOAD = "download"
CHAT_COMPLETION = "chat_completion"
RunnerOpT = TypeVar("RunnerOpT", bound=RunnerOpType)
@@ -47,13 +46,6 @@ class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(default=RunnerOpType.RUNNER_FAILED, frozen=True)
runner_id: RunnerId
class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
op_type: Literal[RunnerOpType.DOWNLOAD] = Field(default=RunnerOpType.DOWNLOAD, frozen=True)
instance_id: InstanceId
runner_id: RunnerId
shard_metadata: ShardMetadata
hosts: list[Host]
class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]):
op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(default=RunnerOpType.CHAT_COMPLETION, frozen=True)
runner_id: RunnerId
@@ -68,7 +60,6 @@ RunnerOp = Annotated[
RunnerUpOp,
RunnerDownOp,
RunnerFailedOp,
DownloadOp,
ExecuteTaskOp,
],
Field(discriminator="op_type")

View File

@@ -12,9 +12,8 @@ from shared.types.worker.shards import ShardMetadata
class RunnerStatusType(str, Enum):
Assigned = "Assigned"
Downloading = "Downloading"
Ready = "Ready"
Inactive = "Inactive"
Starting = "Starting"
Loaded = "Loaded"
Running = "Running"
@@ -28,41 +27,30 @@ class BaseRunnerStatus(BaseModel, Generic[RunnerStatusTypeT]):
runner_status: RunnerStatusTypeT
# Emitted by the Master
class AssignedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Assigned]):
runner_status: Literal[RunnerStatusType.Assigned] = Field(default=RunnerStatusType.Assigned)
# Emitted by the Worker
class DownloadingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Downloading]):
runner_status: Literal[RunnerStatusType.Downloading] = Field(default=RunnerStatusType.Downloading)
download_progress: DownloadProgress
# Emitted by the Worker
class ReadyRunnerStatus(BaseRunnerStatus[RunnerStatusType.Ready]):
runner_status: Literal[RunnerStatusType.Ready] = Field(default=RunnerStatusType.Ready)
class InactiveRunnerStatus(BaseRunnerStatus[RunnerStatusType.Inactive]):
runner_status: Literal[RunnerStatusType.Inactive] = Field(default=RunnerStatusType.Inactive)
# Emitted by the Master
class StartingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Starting]):
runner_status: Literal[RunnerStatusType.Starting] = Field(default=RunnerStatusType.Starting)
# Emitted by the Worker
class LoadedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Loaded]):
runner_status: Literal[RunnerStatusType.Loaded] = Field(default=RunnerStatusType.Loaded)
# Emitted by the Worker
class RunningRunnerStatus(BaseRunnerStatus[RunnerStatusType.Running]):
runner_status: Literal[RunnerStatusType.Running] = Field(default=RunnerStatusType.Running)
# Emitted by the Worker
class FailedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Failed]):
runner_status: Literal[RunnerStatusType.Failed] = Field(default=RunnerStatusType.Failed)
error_message: str | None = None
RunnerStatus = Annotated[
AssignedRunnerStatus
| DownloadingRunnerStatus
| ReadyRunnerStatus
DownloadingRunnerStatus
| InactiveRunnerStatus
| StartingRunnerStatus
| LoadedRunnerStatus
| RunningRunnerStatus

worker/common.py Normal file
View File

@@ -0,0 +1,35 @@
from copy import deepcopy
from typing import Optional
from pydantic import BaseModel, ConfigDict
from shared.types.common import Host
from shared.types.events import (
InstanceId,
RunnerStatusUpdated,
)
from shared.types.worker.common import RunnerId
from shared.types.worker.runners import (
RunnerStatus,
)
from shared.types.worker.shards import ShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
class AssignedRunner(BaseModel):
runner_id: RunnerId
instance_id: InstanceId
shard_metadata: ShardMetadata # just data
hosts: list[Host]
status: RunnerStatus
failures: list[tuple[float, Exception]] = []
runner: Optional[RunnerSupervisor] # set if the runner is 'up'
model_config = ConfigDict(arbitrary_types_allowed=True)
def status_update_event(self) -> RunnerStatusUpdated:
return RunnerStatusUpdated(
runner_id=self.runner_id,
runner_status=deepcopy(self.status),
)

View File

@@ -1,5 +1,3 @@
from pathlib import Path
import pytest
from shared.models.model_meta import get_model_meta
@@ -13,7 +11,7 @@ async def model_meta() -> ModelMetadata:
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path):
def pipeline_shard_meta(model_meta: ModelMetadata):
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:

View File

@@ -1,658 +1,52 @@
import asyncio
import logging
import time
from asyncio import Queue
from copy import deepcopy
from functools import partial
from time import process_time
from typing import AsyncGenerator, Optional
from pydantic import BaseModel, ConfigDict
from shared.apply import apply
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
InstanceDeleted,
InstanceId,
NodePerformanceMeasured,
RunnerDeleted,
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
NodePerformanceMeasured,
)
from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadProgressData,
)
from shared.types.worker.instances import InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerFailedOp,
RunnerOp,
RunnerOpType,
RunnerUpOp,
UnassignRunnerOp,
RunnerOp,
)
from shared.types.worker.runners import (
AssignedRunnerStatus,
DownloadingRunnerStatus,
FailedRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
RunnerStatus,
RunnerStatusType,
RunningRunnerStatus,
)
from shared.types.worker.shards import ShardMetadata
from shared.utils import get_node_id_keypair
from worker.download.impl_shard_downloader import exo_shard_downloader
from worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.plan import plan
from worker.utils.profile import start_polling_node_metrics
from worker.worker import Worker
class AssignedRunner(BaseModel):
runner_id: RunnerId
instance_id: InstanceId
shard_metadata: ShardMetadata # just data
hosts: list[Host]
status: RunnerStatus
failures: list[tuple[float, Exception]] = []
runner: Optional[RunnerSupervisor] # set if the runner is 'up'
model_config = ConfigDict(arbitrary_types_allowed=True)
is_downloaded: bool = False
def set_is_downloaded(self, is_downloaded: bool) -> None:
self.is_downloaded = is_downloaded
def status_update_event(self) -> RunnerStatusUpdated:
return RunnerStatusUpdated(
runner_id=self.runner_id,
runner_status=deepcopy(self.status),
)
class Worker:
def __init__(
self,
node_id: NodeId,
logger: logging.Logger,
shard_downloader: ShardDownloader,
worker_events: AsyncSQLiteEventStorage | None,
global_events: AsyncSQLiteEventStorage | None,
):
self.node_id: NodeId = node_id
self.state: State = State()
self.shard_downloader: ShardDownloader = shard_downloader
self.worker_events: AsyncSQLiteEventStorage | None = worker_events # worker_events is None in some tests.
self.global_events: AsyncSQLiteEventStorage | None = global_events
self.logger: logging.Logger = logger
self.assigned_runners: dict[RunnerId, AssignedRunner] = {}
self._task: asyncio.Task[None] | None = None
## Op Executors
async def _execute_assign_op(
self, op: AssignRunnerOp
) -> AsyncGenerator[Event, None]:
'''
Here, we are sure that the model is already downloaded.
This op moves the runner from Assigned -> Ready state.
'''
self.assigned_runners[op.runner_id] = AssignedRunner(
runner_id=op.runner_id,
instance_id=op.instance_id,
shard_metadata=op.shard_metadata,
hosts=op.hosts,
status=AssignedRunnerStatus(),
runner=None,
)
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_unassign_op(
self, op: UnassignRunnerOp
) -> AsyncGenerator[Event, None]:
if op.runner_id not in self.assigned_runners:
return
# We can try to do a graceful shutdown of the runner.
runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner
if runner is not None:
await runner.astop()
# This is all we really need:
del self.assigned_runners[op.runner_id]
yield RunnerDeleted(runner_id=op.runner_id)
return
yield
async def _execute_runner_up_op(
self, op: RunnerUpOp, initialize_timeout: Optional[float] = None
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
# TODO: This should be dynamic, based on the size of the model.
if not initialize_timeout:
gigabytes_per_second = 10
shard = assigned_runner.shard_metadata
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
initialize_timeout = weights_size_kb / (1024**2 * gigabytes_per_second) + 2.0 # Add a constant 2.0 to ensure connection can be made as well
try:
assigned_runner.runner = await asyncio.wait_for(
RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
logger=self.logger,
),
timeout=initialize_timeout,
)
except TimeoutError as e:
import traceback
tb = traceback.format_exc()
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
yield event
return
if assigned_runner.runner.healthy:
assigned_runner.status = LoadedRunnerStatus()
else:
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_runner_down_op(
self, op: RunnerDownOp
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
if isinstance(assigned_runner.runner, RunnerSupervisor):
await assigned_runner.runner.astop()
assigned_runner.runner = None
assigned_runner.status = ReadyRunnerStatus()
yield assigned_runner.status_update_event()
return
async def _execute_runner_failed_op(
self, op: RunnerFailedOp
) -> AsyncGenerator[Event, None]:
'''
We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down.
'''
assigned_runner = self.assigned_runners[op.runner_id]
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_download_op(
self, op: DownloadOp
) -> AsyncGenerator[Event, None]:
'''
The model needs assigning and then downloading.
This op moves the runner from Assigned -> Downloading -> Ready state.
'''
initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)
if initial_progress.status == "complete":
self.assigned_runners[op.runner_id].set_is_downloaded(True)
self.assigned_runners[op.runner_id].status = DownloadingRunnerStatus(
download_progress=DownloadCompleted(
node_id=self.node_id,
)
)
yield self.assigned_runners[op.runner_id].status_update_event()
self.assigned_runners[op.runner_id].status = ReadyRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
return
initial_status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=initial_progress.total_bytes,
downloaded_bytes=initial_progress.downloaded_bytes
)
)
)
self.assigned_runners[op.runner_id] = AssignedRunner(
runner_id=op.runner_id,
instance_id=op.instance_id,
shard_metadata=op.shard_metadata,
hosts=op.hosts,
status=initial_status,
runner=None,
)
assigned_runner: AssignedRunner = self.assigned_runners[op.runner_id]
yield assigned_runner.status_update_event()
# Download it!
# TODO: we probably want download progress as part of a callback that gets passed to the downloader.
download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()
def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
download_progress_queue.put_nowait(progress)
self.shard_downloader.on_progress(download_progress_callback)
asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))
# TODO: Dynamic timeout, timeout on no packet update received.
timeout_secs = 10 * 60
start_time = process_time()
last_yield_progress = start_time
while process_time() - start_time < timeout_secs:
progress: RepoDownloadProgress = await download_progress_queue.get()
if progress.status == "complete":
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadCompleted(
node_id=self.node_id,
)
)
yield assigned_runner.status_update_event()
assigned_runner.set_is_downloaded(True)
assigned_runner.status = ReadyRunnerStatus()
yield assigned_runner.status_update_event()
break
elif progress.status == "in_progress":
if process_time() - last_yield_progress > 1:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=progress.total_bytes,
downloaded_bytes=progress.downloaded_bytes,
)
)
)
yield assigned_runner.status_update_event()
last_yield_progress = process_time()
else:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadFailed(
node_id=self.node_id,
error_message=f"Timeout downloading model: {op.shard_metadata.model_meta.model_id}"
)
)
yield assigned_runner.status_update_event()
async def _execute_task_op(
self, op: ExecuteTaskOp
) -> AsyncGenerator[Event, None]:
'''
This is the entry point for a chat completion starting.
While there is only one execute function, it will get called in different ways for runner 0 and runner [1, 2, 3, ...].
Runners [1, 2, 3, ...] will run this method when a task is in 'pending' state.
Runner 0 will run this method when a task is in 'running' state.
TODO: How do we handle the logic of ensuring that n-1 nodes have started their execution before allowing the 0'th runner to start?
This is still a little unclear to me.
'''
assigned_runner = self.assigned_runners[op.runner_id]
async def inner_execute(queue: asyncio.Queue[Event]) -> None:
async def running_callback(queue: asyncio.Queue[Event]) -> None:
# Called when the MLX process has been kicked off
assigned_runner.status = RunningRunnerStatus()
await queue.put(assigned_runner.status_update_event())
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.RUNNING,
))
try:
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(ChunkGenerated(
# todo: at some point we will no longer have a bijection between task_id and row_id.
# So we probably want to store a mapping between these two in our Worker object.
command_id=chunk.command_id,
chunk=chunk
))
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.COMPLETE,
))
# After a successful inference:
assigned_runner.status = LoadedRunnerStatus()
await queue.put(assigned_runner.status_update_event())
except Exception as e:
# An exception occurs in the runner supervisor
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
await queue.put(event)
queue: Queue[Event] = asyncio.Queue()
task = asyncio.create_task(inner_execute(queue))
# TODO: Initial (prefill) timeout can be dynamic
# model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
try:
# Yield items from the queue
# timeout = 30.
timeout = 3.
while True:
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
yield item
timeout = 2.
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
):
if isinstance(item.runner_status, LoadedRunnerStatus):
assigned_runner.failures = []
break
except TimeoutError as e:
# Runner supervisor doesn't respond in time; so we put the runner & task into a failed state
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
yield event
finally:
# Ensure the task is cleaned up
try:
await asyncio.wait_for(task, timeout=5)
except asyncio.TimeoutError:
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
## Operation Planner
async def _execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
## It would be great if we can get rid of this async for ... yield pattern.
match op.op_type:
case RunnerOpType.ASSIGN_RUNNER:
event_generator = self._execute_assign_op(op)
case RunnerOpType.UNASSIGN_RUNNER:
event_generator = self._execute_unassign_op(op)
case RunnerOpType.RUNNER_UP:
event_generator = self._execute_runner_up_op(op)
case RunnerOpType.RUNNER_DOWN:
event_generator = self._execute_runner_down_op(op)
case RunnerOpType.RUNNER_FAILED:
event_generator = self._execute_runner_failed_op(op)
case RunnerOpType.DOWNLOAD:
event_generator = self._execute_download_op(op)
case RunnerOpType.CHAT_COMPLETION:
event_generator = self._execute_task_op(op)
async for event in event_generator:
yield event
## Planning logic
def plan(self, state: State) -> RunnerOp | None:
# Compare state to worker 'mood'
# for runner_id, assigned_runner in self.assigned_runners.items():
# if len(assigned_runner.failures) == 3:
# raise Exception(f'Too many errors occurred in assigned runner - assumed to be recurrent and unrecoverable.\nErrors are as follows: {assigned_runner.failures}')
# First, unassign assigned runners that are no longer in the state.
for runner_id, _ in self.assigned_runners.items():
runner_ids: list[RunnerId] = [
runner_id
for instance in state.instances.values()
for runner_id in instance.shard_assignments.runner_to_shard
]
if runner_id not in runner_ids:
return UnassignRunnerOp(runner_id=runner_id)
for runner_id, assigned_runner in self.assigned_runners.items():
if assigned_runner.runner is not None and \
not assigned_runner.runner.healthy and \
not isinstance(assigned_runner.status, FailedRunnerStatus):
return RunnerFailedOp(runner_id=runner_id)
# Then spin down active runners
for _instance_id, instance in state.instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
# We spin down a runner if it's meant to be inactive and it's Loaded.
if runner_id in self.assigned_runners and \
isinstance(self.assigned_runners[runner_id].status, LoadedRunnerStatus) and \
instance.instance_type == InstanceStatus.INACTIVE:
return RunnerDownOp(runner_id=runner_id)
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
# TODO: We need to limit number of retries if we keep failing.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \
not isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, ReadyRunnerStatus): # make sure that our runner has not already been spun down into ready state
other_node_in_instance_has_failed = False
for runner_id in instance.shard_assignments.runner_to_shard:
if runner_id in state.runners and \
isinstance(state.runners[runner_id], FailedRunnerStatus) and \
runner_id not in self.assigned_runners:
other_node_in_instance_has_failed = True
if other_node_in_instance_has_failed:
# Spin down *our* runner
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# If we are failed - and *all of the other nodes have spun down* - then we can spin down too.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \
isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, FailedRunnerStatus):
num_spundown_nodes = 0
for runner_id in instance.shard_assignments.runner_to_shard:
if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
runner_id not in self.assigned_runners:
num_spundown_nodes += 1
# Suggested:
# if runner_id in state.runners and isinstance(state.runners[runner_id], ReadyRunnerStatus):
# if runner_id != instance.shard_assignments.node_to_runner[self.node_id]:
# num_spundown_nodes += 1
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
# All the other nodes are spun down - so now we can spin down too.
# This also catches the case of 1-node. If there's one node in the instance then we should spin down straight away
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# Then assign runners we do want
for instance_id, instance in state.instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
if runner_id not in self.assigned_runners:
return AssignRunnerOp(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
hosts=instance.hosts
)
# Then make sure things are downloading.
for instance_id, instance in state.instances.items():
# We should already have asserted that this runner exists
# If it didn't exist then we return a assign_runner op.
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
assert runner_id in self.assigned_runners
runner = self.assigned_runners[runner_id]
if not runner.is_downloaded:
if runner.status.runner_status == RunnerStatusType.Downloading: # Forward compatibility
# TODO: If failed status then we retry
return None
else:
return DownloadOp(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
hosts=instance.hosts
)
# Then spin up 'ready' runners that should be active
for _instance_id, instance in state.instances.items():
if self.node_id in instance.shard_assignments.node_to_runner and \
self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].runner is None and \
instance.instance_type == InstanceStatus.ACTIVE:
# We are part of this instance, we want it up but it hasn't been spun up yet.
# Need to assert all other runners are ready before we can spin up.
ready_to_spin = True
for runner_id in instance.shard_assignments.node_to_runner.values():
if runner_id in state.runners and state.runners[runner_id].runner_status != RunnerStatusType.Ready:
ready_to_spin = False
if ready_to_spin:
return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[self.node_id])
# Then make sure things are running based on tasks.
for instance_id, instance in state.instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != self.node_id:
continue
assert runner_id in self.assigned_runners
runner = self.assigned_runners[runner_id]
if runner.status.runner_status != RunnerStatusType.Loaded:
continue # The only previous state to get to Running is from Loaded
for _, task in state.tasks.items():
if task.instance_id == instance_id and (
task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED
):
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
return ExecuteTaskOp(runner_id=runner_id, task=task)
else:
# We already know our own status is Loaded. We are rank 0,
# so let's check that all the other runners are running - ready for us to fire the prompt.
running_runner_count = 0
for other_runner_id, other_runner_status in state.runners.items():
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
isinstance(other_runner_status, RunningRunnerStatus):
running_runner_count += 1
if running_runner_count == runner.shard_metadata.world_size - 1:
return ExecuteTaskOp(runner_id=runner_id, task=task)
return None
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
assigned_runner = self.assigned_runners[runner_id]
assigned_runner.runner = None
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
assigned_runner.failures.append(
(
time.time(),
e
)
)
# Reset failure count back to 0 when successful
if len(assigned_runner.failures) >= 3:
# Too many retries. We will emit a DeleteInstance
yield InstanceDeleted(
instance_id=assigned_runner.instance_id
)
yield assigned_runner.status_update_event()
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
yield TaskStateUpdated(
task_id=task_id,
task_status=TaskStatus.FAILED,
)
yield TaskFailed(
task_id=task_id,
error_type=str(type(e)),
error_message=str(e)
)
async for event in self._fail_runner(e, runner_id):
yield event
async def event_publisher(self, event: Event) -> None:
assert self.worker_events is not None
await self.worker_events.append_events([event], self.node_id)
self.logger.info(f"published event: {event}")
# Handle state updates
async def run(self):
assert self.global_events is not None
async def run(worker_state: Worker):
assert worker_state.global_events is not None
while True:
# 1. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
events = await worker_state.global_events.get_events_since(worker_state.state.last_event_applied_idx)
# 2. for each event, apply it to the state and run sagas
for event_from_log in events:
self.state = apply(self.state, event_from_log)
worker_state.state = apply(worker_state.state, event_from_log)
# 3. based on the updated state, we plan & execute an operation.
op: RunnerOp | None = self.plan(self.state)
op: RunnerOp | None = plan(
worker_state.assigned_runners,
worker_state.node_id,
worker_state.state.instances,
worker_state.state.runners,
worker_state.state.tasks,
)
if op is not None:
self.logger.info(f"!!! plan result: {op}")
worker_state.logger.info(f"!!! plan result: {op}")
# run the op, synchronously blocking for now
if op is not None:
try:
async for event in self._execute_op(op):
await self.event_publisher(event)
except Exception as e:
# _execute_task_op already has its own exception handling, so we assume the exception came from one of the other op types.
# We therefore just fail the runner.
self.logger.warning(f"Encountered exception when executing worker op {op}: {e}. \n Runner will be spun down and retried.")
async for event in self._fail_runner(
e,
runner_id=op.runner_id,
):
await self.event_publisher(event)
async for event in worker_state.execute_op(op):
await worker_state.event_publisher(event)
await asyncio.sleep(0.01)
if len(events) > 0:
self.logger.info(f"state: {self.state}")
async def main():
@@ -678,7 +72,7 @@ async def main():
worker = Worker(node_id, logger, shard_downloader, event_log_manager.worker_events, event_log_manager.global_events)
await worker.run()
await run(worker)
if __name__ == "__main__":
asyncio.run(main())

worker/plan.py Normal file
View File

@@ -0,0 +1,205 @@
from typing import Mapping
from shared.types.common import NodeId
from shared.types.events import (
InstanceId,
)
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerFailedOp,
RunnerOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
DownloadingRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunnerStatus,
RunnerStatusType,
RunningRunnerStatus,
)
from worker.common import AssignedRunner
def unassign_runners(instances: Mapping[InstanceId, Instance], state_runners: Mapping[RunnerId, RunnerStatus], assigned_runners: dict[RunnerId, AssignedRunner]) -> UnassignRunnerOp | None:
runner_ids: set[RunnerId] = {
runner_id
for instance in instances.values()
for runner_id in instance.shard_assignments.runner_to_shard
}
for runner_id, _ in assigned_runners.items():
if runner_id not in runner_ids:
return UnassignRunnerOp(runner_id=runner_id)
# If our instance is in 'downloading' or 'assigned' state, then we know the runner is stale. These are part of AssignRunnerOp and should be blocking.
for assigned_runner_id in assigned_runners:
if assigned_runner_id in state_runners and \
isinstance(state_runners[assigned_runner_id], DownloadingRunnerStatus):
return UnassignRunnerOp(runner_id=assigned_runner_id)
return None
def failed_runners(assigned_runners: dict[RunnerId, AssignedRunner]) -> RunnerFailedOp | None:
for runner_id, assigned_runner in assigned_runners.items():
if assigned_runner.runner is not None and \
not assigned_runner.runner.healthy and \
not isinstance(assigned_runner.status, FailedRunnerStatus):
return RunnerFailedOp(runner_id=runner_id)
return None
def spin_down_runners(
instances: Mapping[InstanceId, Instance],
assigned_runners: dict[RunnerId, AssignedRunner],
state_runners: Mapping[RunnerId, RunnerStatus],
worker_node_id: NodeId) -> RunnerDownOp | None:
for _instance_id, instance in instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != worker_node_id:
continue
# We spin down a runner if it's meant to be inactive and it's Loaded.
if runner_id in assigned_runners and \
isinstance(assigned_runners[runner_id].status, LoadedRunnerStatus) and \
instance.instance_type == InstanceStatus.INACTIVE:
return RunnerDownOp(runner_id=runner_id)
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
for _instance_id, instance in instances.items():
if worker_node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[worker_node_id] in assigned_runners and \
not isinstance(assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].status, InactiveRunnerStatus): # make sure that our runner has not already been spun down into the inactive state
other_node_in_instance_has_failed = False
for runner_id in instance.shard_assignments.runner_to_shard:
if runner_id in state_runners and \
isinstance(state_runners[runner_id], FailedRunnerStatus) and \
runner_id not in assigned_runners:
other_node_in_instance_has_failed = True
if other_node_in_instance_has_failed:
# Spin down *our* runner
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id])
# If we are failed - and *all of the other nodes have spun down* - then we can spin down too.
for _instance_id, instance in instances.items():
if worker_node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[worker_node_id] in state_runners and \
instance.shard_assignments.node_to_runner[worker_node_id] in assigned_runners and \
isinstance(assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].status, FailedRunnerStatus):
num_spundown_nodes = 0
for runner_id in instance.shard_assignments.runner_to_shard:
if isinstance(state_runners[runner_id], InactiveRunnerStatus) and \
runner_id not in assigned_runners:
num_spundown_nodes += 1
# Suggested:
# if runner_id in state_runners and isinstance(state_runners[runner_id], InactiveRunnerStatus):
# if runner_id != instance.shard_assignments.node_to_runner[worker_node_id]:
# num_spundown_nodes += 1
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
# All the other nodes are spun down - so now we can spin down too.
# This also catches the case of 1-node. If there's one node in the instance then we should spin down straight away
return RunnerDownOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id])
return None
def assign_runners(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], worker_node_id: NodeId) -> AssignRunnerOp | None:
for instance_id, instance in instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != worker_node_id:
continue
if runner_id not in assigned_runners:
return AssignRunnerOp(
runner_id=runner_id,
instance_id=instance_id,
shard_metadata=instance.shard_assignments.runner_to_shard[runner_id],
hosts=instance.hosts
)
return None
def spin_up_runners(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], state_runners: Mapping[RunnerId, RunnerStatus], worker_node_id: NodeId) -> RunnerUpOp | None:
for _instance_id, instance in instances.items():
if worker_node_id in instance.shard_assignments.node_to_runner and \
assigned_runners[instance.shard_assignments.node_to_runner[worker_node_id]].runner is None and \
instance.instance_type == InstanceStatus.ACTIVE:
# We are part of this instance, we want it up but it hasn't been spun up yet.
# Need to assert all other runners are inactive (downloaded but not yet spun up) before we can spin up.
ready_to_spin = True
for runner_id in instance.shard_assignments.node_to_runner.values():
if runner_id in state_runners and state_runners[runner_id].runner_status != RunnerStatusType.Inactive:
ready_to_spin = False
if ready_to_spin:
return RunnerUpOp(runner_id=instance.shard_assignments.node_to_runner[worker_node_id])
return None
def execute_task_op(instances: Mapping[InstanceId, Instance], assigned_runners: dict[RunnerId, AssignedRunner], state_runners: Mapping[RunnerId, RunnerStatus], tasks: Mapping[TaskId, Task], worker_node_id: NodeId) -> ExecuteTaskOp | None:
for instance_id, instance in instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
if node_id != worker_node_id:
continue
assert runner_id in assigned_runners
runner = assigned_runners[runner_id]
if runner.status.runner_status != RunnerStatusType.Loaded:
continue # The only previous state to get to Running is from Loaded
for _, task in tasks.items():
if task.instance_id == instance_id and (
task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED
):
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
return ExecuteTaskOp(runner_id=runner_id, task=task)
else:
# We already know our own status is Loaded. We are rank 0,
# so let's check that all the other runners are running - ready for us to fire the prompt.
running_runner_count = 0
for other_runner_id, other_runner_status in state_runners.items():
if other_runner_id in instance.shard_assignments.node_to_runner.values() and \
isinstance(other_runner_status, RunningRunnerStatus):
running_runner_count += 1
if running_runner_count == runner.shard_metadata.world_size - 1:
return ExecuteTaskOp(runner_id=runner_id, task=task)
return None
def plan(assigned_runners: dict[RunnerId, AssignedRunner],
worker_node_id: NodeId,
instances: Mapping[InstanceId, Instance],
state_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task]) -> RunnerOp | None:
# First, unassign assigned runners that are no longer in the state.
if unop := unassign_runners(instances, state_runners, assigned_runners):
return unop
# mark failed runners that are not marked yet as failed
if failed_op := failed_runners(assigned_runners):
return failed_op
# spin down runners that are no longer needed
if down_op := spin_down_runners(instances, assigned_runners, state_runners, worker_node_id):
return down_op
# Then assign runners we do want
if assign_op := assign_runners(instances, assigned_runners, worker_node_id):
return assign_op
# Then spin up 'ready' runners that should be active
if runner_up_op := spin_up_runners(instances, assigned_runners, state_runners, worker_node_id):
return runner_up_op
# Then make sure things are running based on tasks.
if exec_op := execute_task_op(instances, assigned_runners, state_runners, tasks, worker_node_id):
return exec_op
return None
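As a sanity check on the fall-through ordering above: with nothing assigned and an empty state, every stage returns None, so plan() itself returns None and the worker loop simply sleeps. A minimal illustrative call (not a test from this commit):
from shared.types.common import NodeId
from worker.plan import plan
# With nothing assigned and no instances, runners or tasks in the state,
# every planner stage falls through and there is no op to execute.
assert plan(assigned_runners={},
            worker_node_id=NodeId(),
            instances={},
            state_runners={},
            tasks={}) is None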

View File

@@ -62,7 +62,7 @@ async def supervisor_read_response(
assert proc.stdout is not None, (
"proc.stdout should not be None when created with stdout=PIPE"
)
line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=10)
line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180)
line: str = line_bytes.decode("utf-8").strip()
if not line:

View File

@@ -1,36 +1,46 @@
import asyncio
from ipaddress import IPv4Address
from logging import Logger, getLogger
from pathlib import Path
from typing import Awaitable, Callable
from typing import Callable, Optional
import pytest
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.models.model_meta import get_model_meta
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import CommandId, Host, NodeId
from shared.types.common import Host, NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.tasks import (
ChatCompletionTask,
TaskId,
TaskStatus,
TaskType,
)
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
RunnerUpOp,
)
from shared.types.worker.runners import RunnerId, ShardAssignments
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import Worker
from worker.tests.constants import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
TASK_1_ID,
)
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "Hello, how are you?"
@pytest.fixture
def logger() -> Logger:
return getLogger("test_logger")
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')
@pytest.fixture
def hosts():
def _hosts(count: int, offset: int = 0) -> list[Host]:
@@ -44,29 +54,8 @@ def hosts():
return _hosts
@pytest.fixture
def hosts_one(hosts: Callable[[int], list[Host]]):
return hosts(1)
@pytest.fixture
def hosts_two(hosts: Callable[[int], list[Host]]):
return hosts(2)
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "Hello, how are you?"
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')
@pytest.fixture
def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[int, int], PipelineShardMetadata]:
def pipeline_shard_meta(model_meta: ModelMetadata) -> Callable[[int, int], PipelineShardMetadata]:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
@@ -90,6 +79,37 @@ def pipeline_shard_meta(model_meta: ModelMetadata, tmp_path: Path) -> Callable[[
return _pipeline_shard_meta
@pytest.fixture
def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]]):
from typing import Optional
def _instance(
instance_id: Optional[InstanceId] = None,
node_id: Optional[NodeId] = None,
runner_id: Optional[RunnerId] = None,
model_id: Optional[ModelId] = None,
) -> Instance:
resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
resolved_node_id = node_id if node_id is not None else NODE_A
resolved_runner_id = runner_id if runner_id is not None else RUNNER_1_ID
resolved_model_id = model_id if model_id is not None else MODEL_A_ID
shard_assignments = ShardAssignments(
model_id=resolved_model_id,
runner_to_shard={
resolved_runner_id: pipeline_shard_meta(1, 0)
},
node_to_runner={resolved_node_id: resolved_runner_id}
)
return Instance(
instance_id=resolved_instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(1)
)
return _instance
@pytest.fixture
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
"""Creates ChatCompletionParams with the given message"""
@@ -101,10 +121,14 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
@pytest.fixture
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
def _chat_completion_task(instance_id: InstanceId, task_id: TaskId) -> ChatCompletionTask:
def _chat_completion_task(instance_id: Optional[InstanceId] = None, task_id: Optional[TaskId] = None) -> ChatCompletionTask:
if instance_id is None:
instance_id = INSTANCE_1_ID
if task_id is None:
task_id = TASK_1_ID
return ChatCompletionTask(
task_id=task_id,
command_id=CommandId(),
command_id=COMMAND_1_ID,
instance_id=instance_id,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
@@ -112,105 +136,4 @@ def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
)
return _chat_completion_task
@pytest.fixture
def node_id() -> NodeId:
"""Shared node ID for tests"""
return NodeId()
@pytest.fixture
def state(node_id: NodeId):
node_status={
node_id: NodeStatus.Idle
}
return State(
node_status=node_status,
)
@pytest.fixture
def logger() -> Logger:
return getLogger("test_logger")
@pytest.fixture
def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]):
def _instance(instance_id: InstanceId, node_id: NodeId, runner_id: RunnerId) -> Instance:
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
runner_id: pipeline_shard_meta(1, 0)
},
node_to_runner={node_id: runner_id}
)
return Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts_one
)
return _instance
@pytest.fixture
async def worker(node_id: NodeId, logger: Logger):
event_log_manager = EventLogManager(EventLogConfig(), logger)
shard_downloader = NoopShardDownloader()
await event_log_manager.initialize()
return Worker(node_id, logger, shard_downloader, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events)
@pytest.fixture
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
"""Fixture that provides a worker with an already assigned runner."""
instance_obj: Instance = instance(InstanceId(), worker.node_id, RunnerId())
# Extract runner_id from shard assignments
runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))
# Assign the runner
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
async for _ in worker._execute_op(assign_op): # type: ignore[misc]
pass
return worker, runner_id, instance_obj
@pytest.fixture
async def worker_with_running_runner(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance]):
"""Fixture that provides a worker with an already assigned runner."""
worker, runner_id, instance_obj = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
async for _ in worker._execute_op(runner_up_op): # type: ignore[misc]
pass
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
return worker, runner_id, instance_obj
@pytest.fixture
def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]:
async def _worker_running(node_id: NodeId) -> tuple[Worker, AsyncSQLiteEventStorage]:
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
shard_downloader = NoopShardDownloader()
worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker.run())
return worker, global_events
return _worker_running

26
worker/tests/constants.py Normal file
View File

@@ -0,0 +1,26 @@
from typing import Final
from shared.types.common import CommandId, NodeId
from shared.types.models import ModelId
from shared.types.tasks import TaskId
from shared.types.worker.common import InstanceId, RunnerId
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/TinyLlama-1.1B-Chat-v1.0'
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")

View File

@@ -8,6 +8,7 @@ from worker.download.impl_shard_downloader import exo_shard_downloader
from worker.download.shard_downloader import ShardDownloader
@pytest.mark.slow
@pytest.mark.asyncio
async def test_shard_downloader(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata]):
shard_downloader: ShardDownloader = exo_shard_downloader()

View File

@@ -0,0 +1,70 @@
from logging import Logger
from typing import Callable
import pytest
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import NodeId
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance
from shared.types.worker.ops import (
AssignRunnerOp,
RunnerUpOp,
)
from shared.types.worker.runners import RunnerId
from worker.download.shard_downloader import NoopShardDownloader
from worker.tests.constants import INSTANCE_1_ID, NODE_A, RUNNER_1_ID
from worker.worker import Worker
@pytest.fixture
def user_message():
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
@pytest.fixture
async def worker(logger: Logger):
event_log_manager = EventLogManager(EventLogConfig(), logger)
shard_downloader = NoopShardDownloader()
await event_log_manager.initialize()
return Worker(NODE_A, logger, shard_downloader, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events)
# TODO: instance_id and runner_id are selectable.
@pytest.fixture
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
"""Fixture that provides a worker with an already assigned runner."""
instance_id = INSTANCE_1_ID
runner_id = RUNNER_1_ID
instance_obj: Instance = instance(instance_id, worker.node_id, runner_id)
# Assign the runner
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
async for _ in worker.execute_op(assign_op):
pass
return worker, instance_obj
@pytest.fixture
async def worker_with_running_runner(worker_with_assigned_runner: tuple[Worker, Instance]):
"""Fixture that provides a worker with an already assigned runner."""
worker, instance_obj = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
async for _ in worker.execute_op(runner_up_op):
pass
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
return worker, instance_obj

View File

@@ -0,0 +1,159 @@
from typing import Callable
import pytest
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
RunnerDeleted,
RunnerStatusUpdated,
TaskStateUpdated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import ChatCompletionTask, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance, InstanceId
from shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
DownloadingRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunningRunnerStatus,
)
from worker.main import Worker
from worker.tests.constants import (
RUNNER_1_ID,
)
from worker.tests.test_handlers.utils import read_events_op
@pytest.mark.asyncio
async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
instance_obj: Instance = instance(InstanceId(), worker.node_id, RUNNER_1_ID)
assign_op = AssignRunnerOp(
runner_id=RUNNER_1_ID,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[RUNNER_1_ID],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
events = await read_events_op(worker, assign_op)
# We should see two status updates: first downloading, then inactive.
assert len(events) == 2
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, DownloadingRunnerStatus)
assert isinstance(events[1], RunnerStatusUpdated)
assert isinstance(events[1].runner_status, InactiveRunnerStatus)
# And the runner should be assigned
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
@pytest.mark.asyncio
async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, Instance]):
worker, _ = worker_with_assigned_runner
unassign_op = UnassignRunnerOp(
runner_id=RUNNER_1_ID
)
events = await read_events_op(worker, unassign_op)
# We should have no assigned runners, and a single RunnerDeleted event should have been emitted
assert len(worker.assigned_runners) == 0
assert len(events) == 1
assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
events = await read_events_op(worker, runner_up_op)
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, LoadedRunnerStatus)
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task()):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
assert "42" in full_response.lower(), (
f"Expected '42' in response, but got: {full_response}"
)
runner = worker.assigned_runners[RUNNER_1_ID].runner
assert runner is not None
await runner.astop() # Neat cleanup.
@pytest.mark.asyncio
async def test_runner_down_op(worker_with_running_runner: tuple[Worker, Instance]):
worker, _ = worker_with_running_runner
runner_down_op = RunnerDownOp(runner_id=RUNNER_1_ID)
events = await read_events_op(worker, runner_down_op)
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, InactiveRunnerStatus)
@pytest.mark.asyncio
async def test_execute_task_op(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask]):
worker, _ = worker_with_running_runner
execute_task_op = ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=chat_completion_task()
)
events = await read_events_op(worker, execute_task_op)
assert len(events) > 20
print(f'{events=}')
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus)
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
assert isinstance(events[-2], TaskStateUpdated)
assert events[-2].task_status == TaskStatus.COMPLETE # The task ran to completion.
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(events[-1].runner_status, LoadedRunnerStatus) # It should not have failed.
gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
assert len(text_chunks) == len(events) - 4
output_text = ''.join([x.text for x in text_chunks])
assert '42' in output_text
runner = worker.assigned_runners[RUNNER_1_ID].runner
assert runner is not None
await runner.astop() # Neat cleanup.

View File

@@ -0,0 +1,61 @@
## Tests for worker state handlers
from typing import Callable
import pytest
from shared.types.events import (
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.tasks import ChatCompletionTask, TaskStatus
from shared.types.worker.instances import Instance
from shared.types.worker.ops import (
ExecuteTaskOp,
)
from shared.types.worker.runners import (
FailedRunnerStatus,
RunningRunnerStatus,
)
from worker.main import Worker
from worker.tests.constants import RUNNER_1_ID
from worker.tests.test_handlers.utils import read_events_op
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask]):
worker, _ = worker_with_running_runner
task = chat_completion_task()
messages = task.task_params.messages
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
execute_task_op = ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=task
)
events = await read_events_op(worker, execute_task_op)
assert len(events) == 5
print(events)
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start.
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
assert isinstance(events[2], TaskStateUpdated)
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
assert isinstance(events[3], TaskFailed)
assert isinstance(events[4], RunnerStatusUpdated)
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.
# TODO: Much more to do here!

View File

@@ -0,0 +1,18 @@
## Tests for worker state handlers
from shared.types.events import (
Event,
)
from shared.types.worker.ops import (
RunnerOp,
)
from worker.main import Worker
async def read_events_op(worker: Worker, op: RunnerOp) -> list[Event]:
events: list[Event] = []
async for event in worker.execute_op(op):
events.append(event)
return events
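# Example usage, mirroring the handler tests (the op shown is illustrative):
#
#   events = await read_events_op(worker, RunnerUpOp(runner_id=RUNNER_1_ID))
#   assert isinstance(events[0], RunnerStatusUpdated)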

View File

@@ -0,0 +1,36 @@
import asyncio
from logging import Logger
from typing import Awaitable, Callable
import pytest
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import NodeId
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.worker import Worker
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "What is the capital of Japan?"
@pytest.fixture
def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]:
async def _worker_running(node_id: NodeId) -> tuple[Worker, AsyncSQLiteEventStorage]:
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
shard_downloader = NoopShardDownloader()
worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker))
return worker, global_events
return _worker_running

View File

@@ -1,14 +1,11 @@
import asyncio
from logging import Logger
from typing import Awaitable, Callable, Final
import pytest
from typing import Awaitable, Callable
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import CommandId, Host, NodeId
from shared.types.common import Host, NodeId
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
@@ -18,7 +15,7 @@ from shared.types.events import (
)
from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
@@ -26,35 +23,31 @@ from shared.types.worker.instances import (
ShardAssignments,
)
from shared.types.worker.runners import (
AssignedRunnerStatus,
DownloadingRunnerStatus,
# RunningRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.common import AssignedRunner
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import AssignedRunner, Worker
from worker.tests.test_worker_integration_utils import read_streaming_response
from worker.main import run
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
TASK_2_ID,
)
from worker.tests.test_integration.integration_utils import (
read_streaming_response,
)
from worker.worker import Worker
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
@pytest.fixture
def user_message():
return "What is the capital of Japan?"
async def test_runner_assigned(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
@@ -63,8 +56,6 @@ async def test_runner_assigned(
worker, global_events = await worker_running(NODE_A)
print(worker)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.INACTIVE
@@ -82,22 +73,19 @@ async def test_runner_assigned(
# Ensure the worker has taken the correct action
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, ReadyRunnerStatus)
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
print(events)
assert len(events) >= 4 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
assert len(events) >= 3 # len(events) is 3 if it's already downloaded. It is > 3 if there have to be download events.
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, AssignedRunnerStatus)
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, ReadyRunnerStatus)
assert isinstance(events[-1].event.runner_status, InactiveRunnerStatus)
# Ensure state is correct
assert isinstance(worker.state.runners[RUNNER_1_ID], ReadyRunnerStatus)
assert isinstance(worker.state.runners[RUNNER_1_ID], InactiveRunnerStatus)
async def test_runner_assigned_active(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
@@ -118,7 +106,7 @@ async def test_runner_assigned_active(
origin=MASTER_NODE_ID
)
await asyncio.sleep(1.0)
await asyncio.sleep(2.0)
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
@@ -126,13 +114,11 @@ async def test_runner_assigned_active(
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
assert len(events) >= 5 # len(events) is 5 if it's already downloaded. It is > 5 if there have to be download events.
assert len(events) >= 4 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, AssignedRunnerStatus)
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
assert isinstance(events[-2].event, RunnerStatusUpdated)
assert isinstance(events[-2].event.runner_status, ReadyRunnerStatus)
assert isinstance(events[-2].event.runner_status, InactiveRunnerStatus)
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
@@ -201,7 +187,7 @@ async def test_runner_unassigns(
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.5)
await asyncio.sleep(2.0)
# already tested by test_runner_assigned_active
assert len(worker.assigned_runners) == 1
@@ -210,12 +196,11 @@ async def test_runner_unassigns(
# Ensure the correct events have been emitted (creation)
events = await global_events.get_events_since(0)
assert len(events) >= 5
assert len(events) >= 4
assert isinstance(events[-1].event, RunnerStatusUpdated)
assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
# Ensure state is correct
print(worker.state)
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
await global_events.append_events(
@@ -227,7 +212,6 @@ async def test_runner_unassigns(
await asyncio.sleep(0.3)
print(worker.state)
assert len(worker.assigned_runners) == 0
# Ensure the correct events have been emitted (deletion)
@@ -236,221 +220,6 @@ async def test_runner_unassigns(
# After deletion, runner should be removed from state.runners
assert len(worker.state.runners) == 0
async def test_runner_inference(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
_worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance_value,
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
async def test_2_runner_inference(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)
async def test_2_runner_multi_message(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
# Task - this scenario uses a three-message conversation, which is the behaviour under test
completion_create_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(role="user", content='What is the capital of France?'),
ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'),
ChatCompletionMessage(role="user", content='Ok great. Now write me a haiku about what you can do there.'),
],
stream=True,
)
task = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=CommandId(),
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love'))
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)
async def test_runner_respawn(
@@ -467,10 +236,10 @@ async def test_runner_respawn(
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
asyncio.create_task(run(worker1))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
asyncio.create_task(run(worker2))
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
@@ -534,21 +303,18 @@ async def test_runner_respawn(
await asyncio.sleep(5.0)
events = await global_events.get_events_since(idx)
print(f'{events=}')
# assert len(events) == 2
assert isinstance(events[0].event, RunnerStatusUpdated)
assert isinstance(events[0].event.runner_status, FailedRunnerStatus)
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, ReadyRunnerStatus)
assert isinstance(events[1].event.runner_status, InactiveRunnerStatus)
assert events[1].event.runner_id == RUNNER_2_ID
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, ReadyRunnerStatus)
assert isinstance(events[2].event.runner_status, InactiveRunnerStatus)
assert events[2].event.runner_id == RUNNER_1_ID
print(worker1.state)
print(worker2.state)
for event in [events[3].event, events[4].event]:
assert isinstance(event, RunnerStatusUpdated)

View File

@@ -0,0 +1,256 @@
import asyncio
from logging import Logger
from typing import Awaitable, Callable
# TaskStateUpdated and ChunkGenerated are used in test_integration/integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import CommandId, Host, NodeId
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
TaskCreated,
)
from shared.types.models import ModelId
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
)
from worker.tests.test_integration.integration_utils import (
read_streaming_response,
)
from worker.worker import Worker
async def test_runner_inference(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
_worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance_value,
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
async def test_2_runner_inference(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker1))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker2))
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)
# TODO: Multi message parallel
async def test_2_runner_multi_message(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker1))
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(run(worker2))
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
# Task - this scenario uses a three-message conversation, which is the behaviour under test
completion_create_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(role="user", content='What is the capital of France?'),
ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'),
ChatCompletionMessage(role="user", content='Ok great. Now write me a haiku about what you can do there.'),
],
stream=True,
)
task = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=CommandId(),
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love'))
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)

View File

@@ -1,7 +1,7 @@
import asyncio
from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Awaitable, Callable, Final
from typing import Any, Awaitable, Callable
import pytest
from _pytest.monkeypatch import MonkeyPatch
@@ -15,11 +15,9 @@ from shared.types.events import (
InstanceDeleted,
RunnerStatusUpdated,
TaskCreated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
@@ -29,20 +27,14 @@ from shared.types.worker.instances import (
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
TASK_1_ID,
)
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
@pytest.fixture
def user_message():
@@ -187,65 +179,65 @@ async def test_stream_response_failed_once(
await asyncio.sleep(0.3)
async def test_stream_response_timeout(
monkeypatch: MonkeyPatch,
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
) -> AsyncGenerator[GenerationChunk]:
# TODO: Also a test where we yield a few chunks and then time out.
print('sleeping starting')
await asyncio.sleep(4.)
print('sleeping finished')
return
yield
# async def test_stream_response_timeout(
# monkeypatch: MonkeyPatch,
# worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
# instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
# chat_completion_task: Callable[[InstanceId, TaskId], Task]
# ):
# async def mock_stream_response(
# self: RunnerSupervisor,
# task: Task,
# request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
# ) -> AsyncGenerator[GenerationChunk]:
# # TODO: Also a test where we yield a few chunks and then time out.
# print('sleeping starting')
# await asyncio.sleep(4.)
# print('sleeping finished')
# return
# yield
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
# monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
worker, global_events = await worker_running(NODE_A)
# worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
# instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
# instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task)
],
origin=MASTER_NODE_ID
)
# task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
# await global_events.append_events(
# [
# InstanceCreated(instance=instance_value),
# TaskCreated(task_id=task.task_id, task=task)
# ],
# origin=MASTER_NODE_ID
# )
await asyncio.sleep(7.)
# await asyncio.sleep(7.)
# as we reset the failures back to zero when we have a successful inference.
# # as we reset the failures back to zero when we have a successful inference.
# print('ASSERTION ERR:')
# print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])
# # print('ASSERTION ERR:')
# # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
assert worker.state.tasks[TASK_1_ID].error_type is None
assert worker.state.tasks[TASK_1_ID].error_message is None
# assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
# assert worker.state.tasks[TASK_1_ID].error_type is None
# assert worker.state.tasks[TASK_1_ID].error_message is None
events = await global_events.get_events_since(0)
print(events)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1
# events = await global_events.get_events_since(0)
# print(events)
# assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
# assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
# assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
# await global_events.append_events(
# [
# InstanceDeleted(
# instance_id=instance_value.instance_id,
# ),
# ],
# origin=MASTER_NODE_ID
# )
await asyncio.sleep(0.3)
# await asyncio.sleep(0.3)

View File

@@ -0,0 +1,540 @@
from __future__ import annotations
import logging
import pytest
from shared.types.api import ChatCompletionMessage
from shared.types.state import State
from shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
TaskStatus,
TaskType,
)
from shared.types.worker.common import NodeStatus
from shared.types.worker.downloads import (
DownloadPending,
)
from shared.types.worker.instances import InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
DownloadingRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunningRunnerStatus,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.common import AssignedRunner
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import Worker
from worker.plan import plan
from worker.tests.constants import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
)
from worker.tests.test_plan.test_worker_plan_utils import (
InProcessRunner,
PlanTestCase,
make_downloading_status,
make_model_meta,
make_state,
make_test_case,
)
"""
These tests declaratively define the input and the expected output of the worker.plan function.
We initialize a Worker with InProcessRunners, construct a State, pass it to Worker.plan,
and then check which operation (if any) Worker.plan returns.
Note that the 'self' node is always NODE_A; this is why some of the failure-case scenarios
have their runner/node roles swapped around.
"""
def _get_test_cases() -> list[PlanTestCase]:
# The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation.
model_a_meta = make_model_meta(MODEL_A_ID)
return [
PlanTestCase(
description="no runners -> no-op",
in_process_runners=[],
state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
expected_op=None,
),
# Both 'assigned' and 'downloading' should be blocking ops - so if we are in either of these we should unassign to retry.
# This needs to change when we move to an async worker
make_test_case(
description="runner state assigned, runner is assigned and downloading -> unassign",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': make_downloading_status(NODE_A),
'downloaded': False
}],
instance_status=InstanceStatus.INACTIVE,
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="ready runner, model present -> no-op",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': InactiveRunnerStatus(),
'downloaded': True
}],
instance_status=InstanceStatus.INACTIVE,
expected_op=None,
),
PlanTestCase(
description="runner assigned and not in state -> AssignRunnerOp",
in_process_runners=[],
state=make_state(
runner_specs_per_instance={
INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())]
},
model_id=MODEL_A_ID,
instance_status=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
),
expected_op=AssignRunnerOp(
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
shard_metadata=PipelineShardMetadata(
device_rank=0,
world_size=1,
model_meta=model_a_meta,
start_layer=0,
end_layer=1,
n_layers=1,
),
hosts=[]
),
),
PlanTestCase(
description="runner assigned but no longer in state -> UnassignRunnerOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=InactiveRunnerStatus(),
downloaded=False,
)
],
state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="ready runner (and state up) -> expect RunnerUpOp",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': InactiveRunnerStatus(),
'downloaded': True
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="1 ready, 1 downloading (and state up) -> no-op",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': InactiveRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A)),
'downloaded': False
}
],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=None
),
make_test_case(
description="2 ready runners (and state up) -> expect RunnerUpOp",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': InactiveRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': InactiveRunnerStatus(),
'downloaded': True
}
],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID)
),
make_test_case(
description="loaded runner (and state down) -> expect RunnerDownOp",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
}],
instance_status=InstanceStatus.INACTIVE,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="failed runner (and state down) -> expect RunnerDownOp",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': FailedRunnerStatus(),
'downloaded': True
}],
instance_status=InstanceStatus.INACTIVE,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="loaded runner, model present, task pending -> expect ExecuteTaskOp",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
}],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")]
),
)),
),
# We should only run rank 0 once all other ranks are running.
make_test_case(
description="two loaded runners & task, i'm rank 0 -> no-op",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': LoadedRunnerStatus(),
'downloaded': True
}
],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=None
),
make_test_case(
description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 1,
'status': LoadedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
}
],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
),
task_status=TaskStatus.PENDING,
),
),
),
make_test_case(
description="rank 1 loaded, rank 0 ready, i'm rank 0 -> expect ExecuteTaskOp on rank 0",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': RunningRunnerStatus(),
'downloaded': True
}
],
tasks=[{
'task_id': TASK_1_ID,
'instance_id': INSTANCE_1_ID,
'status': TaskStatus.PENDING,
'messages': [{'role': 'user', 'content': 'Hello, world!'}]
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
),
task_status=TaskStatus.PENDING,
),
),
),
make_test_case(
description="this runner failed (1 node) -> RunnerDownOp",
runner_specs=[{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': FailedRunnerStatus(),
'downloaded': True
}],
instance_status=InstanceStatus.ACTIVE,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
make_test_case(
description="other runner failed -> RunnerDownOp",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': LoadedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': FailedRunnerStatus(),
'downloaded': True
}
],
instance_status=InstanceStatus.ACTIVE,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
make_test_case(
description="this runner failed (2 nodes) -> no-op",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': FailedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': LoadedRunnerStatus(),
'downloaded': True
}
],
instance_status=InstanceStatus.ACTIVE,
expected_op=None
),
make_test_case(
description="this node failed, other node spun down -> RunnerDownOp",
runner_specs=[
{
'runner_id': RUNNER_1_ID,
'node_id': NODE_A,
'device_rank': 0,
'status': FailedRunnerStatus(),
'downloaded': True
},
{
'runner_id': RUNNER_2_ID,
'node_id': NODE_B,
'device_rank': 1,
'status': InactiveRunnerStatus(),
'downloaded': True
}
],
instance_status=InstanceStatus.ACTIVE,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
]
# ---------------------------------------------------------------------------
# Parametrised test
# ---------------------------------------------------------------------------
# Pre-compute readable identifiers for each case to avoid lambda typing issues.
@pytest.mark.parametrize(
"case",
# We use a factory so the test cases can be regenerated fresh inside each test run.
[pytest.param(c, id=c.id()) for c in _get_test_cases()],
)
def test_worker_plan(case: PlanTestCase) -> None:
"""Exercise Worker.plan across declarative scenarios."""
print(f"----- case: {case.description}")
# Regenerate the test cases so each run works on a fresh, independent copy
test_cases = {c.description: c for c in _get_test_cases()}
case = test_cases[case.description]
node_id = NODE_A
logger = logging.getLogger("test_worker_plan")
shard_downloader = NoopShardDownloader()
worker = Worker(node_id=node_id, shard_downloader=shard_downloader, worker_events=None, global_events=None, logger=logger)
runner_config: InProcessRunner
for runner_config in case.in_process_runners:
if len(case.state.instances) == 1:
instance_id = next(iter(case.state.instances))
shard_assignments = case.state.instances[instance_id].shard_assignments
shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
# Only add this runner if it belongs to our node
runner_node = None
for node, runner in shard_assignments.node_to_runner.items():
if runner == runner_config.runner_id:
runner_node = node
break
if runner_node != node_id:
# This runner belongs to a different node, skip it
continue
elif len(case.state.instances) == 0:
shard_metadata = PipelineShardMetadata(
device_rank=runner_config.device_rank,
world_size=1,
model_meta=make_model_meta(runner_config.model_id),
start_layer=0,
end_layer=1,
n_layers=1,
)
else:
raise Exception('test_worker_plan not currently designed to have more than 1 instance.')
assigned_runner = AssignedRunner(
runner_id=runner_config.runner_id,
instance_id=runner_config.instance_id,
shard_metadata=shard_metadata,
hosts=[],
status=runner_config.status,
runner=None,
)
worker.assigned_runners[runner_config.runner_id] = assigned_runner
op = plan(worker.assigned_runners,
NODE_A,
case.state.instances,
case.state.runners,
case.state.tasks,
)
assert op == case.expected_op

View File

@@ -0,0 +1,272 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, NotRequired, Optional, TypedDict
from typing_extensions import Literal
from shared.models.model_cards import MODEL_CARDS, ModelCard
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import CommandId, NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import RunnerOp
from shared.types.worker.runners import (
DownloadingRunnerStatus,
RunnerStatus,
RunningRunnerStatus,
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.tests.constants import COMMAND_1_ID, INSTANCE_1_ID, MODEL_A_ID
class RunnerSpecDict(TypedDict):
"""Type definition for runner specification dictionaries."""
runner_id: RunnerId
node_id: NodeId
device_rank: int
status: RunnerStatus
downloaded: NotRequired[bool] # defaults to True if not provided
class MessageDict(TypedDict):
"""Type definition for message dictionaries."""
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: NotRequired[str | None]
name: NotRequired[str | None]
tool_calls: NotRequired[list[dict[str, str]] | None]
tool_call_id: NotRequired[str | None]
function_call: NotRequired[dict[str, str] | None]
class TaskSpecDict(TypedDict):
"""Type definition for task specification dictionaries."""
task_id: TaskId
instance_id: NotRequired[InstanceId] # defaults to function parameter if not provided
command_id: NotRequired[CommandId] # defaults to COMMAND_1_ID if not provided
status: NotRequired[TaskStatus] # defaults to TaskStatus.PENDING if not provided
model: NotRequired[str] # defaults to model_id if not provided
messages: NotRequired[list[MessageDict]] # defaults to [{'role': 'user', 'content': 'Hello, world!'}] if not provided
@dataclass(slots=True, frozen=True)
class InProcessRunner:
"""Minimal description of a runner's in-process state."""
runner_id: RunnerId
instance_id: InstanceId
model_id: ModelId
status: RunnerStatus
downloaded: bool
device_rank: int = 0
@dataclass(slots=True, frozen=True)
class PlanTestCase:
"""Table-driven description of an entire planning scenario."""
description: str
state: State
in_process_runners: List[InProcessRunner]
expected_op: Optional[RunnerOp]
def id(self) -> str: # noqa: D401
return self.description.replace(" ", "_")
def make_shard_metadata(device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID) -> PipelineShardMetadata:
"""Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
total_layers = world_size # For simplicity in tests, total_layers = world_size
if world_size == 1:
start_layer = 0
end_layer = 1
n_layers = 1
else:
# For multi-device setup, each device gets one layer
start_layer = device_rank
end_layer = device_rank + 1
n_layers = total_layers
return PipelineShardMetadata(
device_rank=device_rank,
world_size=world_size,
model_meta=make_model_meta(model_id),
start_layer=start_layer,
end_layer=end_layer,
n_layers=n_layers,
)
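# Worked example: device_rank=1, world_size=3 yields start_layer=1, end_layer=2, n_layers=3;
# with world_size=1 the single shard always covers layers [0, 1).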
def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
"""Factory for a *Downloading* status with placeholder progress."""
return DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(total_bytes=1, downloaded_bytes=0),
)
)
def make_model_meta(
model_id: str
) -> ModelMetadata:
model_card: ModelCard
for card in MODEL_CARDS.values():
if card.model_id == model_id:
model_card = card
return ModelMetadata(
model_id=model_id,
pretty_name=model_card.model_id,
storage_size_kilobytes=10**6,
n_layers=16,
)
raise Exception(f'Unknown model_id passed: {model_id}')
## Alternatively, if we are OK with this method being async:
# await _get_model_meta(model_id)
def make_instance(
instance_id: InstanceId,
runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, NodeStatus]]:
"""Creates an instance with one or more runners."""
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
world_size = len(runner_specs)
for runner_id, node_id, device_rank, _ in runner_specs:
shard_metadata = make_shard_metadata(
device_rank,
world_size,
model_id
)
runner_to_shard[runner_id] = shard_metadata
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
instance = Instance(
instance_id=instance_id,
instance_type=instance_status,
shard_assignments=shard_assignments,
hosts=[],
)
# Nodes are normally Idle here: if a node were Running, the worker loop would be blocked and plan() would not be invoked.
# node_statuses = {node_id: NodeStatus.Idle for _, node_id, _, _ in runner_specs}
node_statuses: dict[NodeId, NodeStatus] = {}
for _runner_id, node_id, _, status in runner_specs:
if isinstance(status, RunningRunnerStatus):
node_statuses[node_id] = NodeStatus.Running
else:
node_statuses[node_id] = NodeStatus.Idle
runner_statuses = {runner_id: status for runner_id, _, _, status in runner_specs}
return instance, runner_statuses, node_statuses
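# Note: node statuses are derived from the runner statuses above; a RunningRunnerStatus marks its
# node as Running, anything else leaves the node Idle.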
def make_state(
runner_specs_per_instance: dict[InstanceId, list[tuple[RunnerId, NodeId, int, RunnerStatus]]],
tasks: dict[TaskId, ChatCompletionTask] | None = None,
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
) -> State:
"""Builds a full State from runner specs per instance, tasks, and defaults."""
if tasks is None:
tasks = {}
instances: dict[InstanceId, Instance] = {}
all_runner_statuses: dict[RunnerId, RunnerStatus] = {}
all_node_statuses: dict[NodeId, NodeStatus] = {}
for inst_id, specs in runner_specs_per_instance.items():
# Build per-instance data using make_instance
instance, runner_statuses, node_statuses = make_instance(
instance_id=inst_id,
runner_specs=specs,
model_id=model_id,
instance_status=instance_status,
)
instances[inst_id] = instance
all_runner_statuses.update(runner_statuses)
all_node_statuses.update(node_statuses)
return State(
node_status=all_node_statuses,
instances=instances,
runners=all_runner_statuses,
tasks=tasks,
)
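# make_state merges per-instance runner and node statuses into a single State; if two instances
# reuse a runner or node id, the later instance's entry wins (plain dict.update semantics).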
def make_test_case(
description: str,
runner_specs: list[RunnerSpecDict],
tasks: list[TaskSpecDict] | None = None,
expected_op: Optional[RunnerOp] = None,
instance_id: InstanceId = INSTANCE_1_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
model_id: ModelId = MODEL_A_ID,
command_id: CommandId = COMMAND_1_ID, # Default for tasks
) -> PlanTestCase:
"""Builds a PlanTestCase from high-level specs."""
if tasks is None:
tasks = []
# Convert runner_specs to tuple format for make_instance
specs_tuple = [
(r['runner_id'], r['node_id'], r['device_rank'], r['status'])
for r in runner_specs
]
# Build state using make_state (wrap single instance)
state_tasks: dict[TaskId, ChatCompletionTask] = {}
for t in tasks:
task = ChatCompletionTask(
instance_id=instance_id,
task_id=t['task_id'],
command_id=t.get('command_id', command_id),
task_type=TaskType.CHAT_COMPLETION,
task_status=t.get('status', TaskStatus.PENDING),
task_params=ChatCompletionTaskParams(
model=t.get('model', str(model_id)),
messages=[ChatCompletionMessage(**m) for m in t.get('messages', [{'role': 'user', 'content': 'Hello, world!'}])],
),
)
state_tasks[t['task_id']] = task
state = make_state(
runner_specs_per_instance={instance_id: specs_tuple},
tasks=state_tasks,
model_id=model_id,
instance_status=instance_status,
)
# Build in_process_runners with downloaded (default True if missing)
in_process_runners = [
InProcessRunner(
runner_id=r['runner_id'],
instance_id=instance_id,
model_id=model_id,
status=r['status'],
downloaded=r.get('downloaded', True),
device_rank=r['device_rank'],
) for r in runner_specs
]
return PlanTestCase(
description=description,
state=state,
in_process_runners=in_process_runners,
expected_op=expected_op,
)
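# A minimal, assumed usage sketch (not part of this module): the builders above are meant to feed a
# table-driven pytest. Names like RUNNER_1_ID, NODE_A and ReadyRunnerStatus below are placeholders
# drawn from the surrounding test suite, not definitions made in this file.
#
#   CASES = [
#       make_test_case(
#           description="ready runner, no tasks -> no-op",
#           runner_specs=[{
#               'runner_id': RUNNER_1_ID,
#               'node_id': NODE_A,
#               'device_rank': 0,
#               'status': ReadyRunnerStatus(),
#           }],
#           expected_op=None,
#       ),
#   ]
#
#   @pytest.mark.parametrize("case", [pytest.param(c, id=c.id()) for c in CASES])
#   def test_plan_from_case(case: PlanTestCase) -> None:
#       ...  # build a Worker from case.in_process_runners and assert plan(...) == case.expected_op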

View File

@@ -9,13 +9,13 @@ from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.events import InstanceCreated, InstanceDeleted
from shared.types.models import ModelId
from shared.types.tasks import Task
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from shared.types.worker.runners import FailedRunnerStatus
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import Worker
from worker.main import run
from worker.worker import Worker
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
@@ -42,7 +42,6 @@ async def check_runner_connection(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, str], Task],
) -> bool:
# Track all tasks and workers for cleanup
tasks: list[asyncio.Task[None]] = []
@@ -64,7 +63,7 @@ async def check_runner_connection(
global_events=global_events,
)
workers.append(worker1)
task1 = asyncio.create_task(worker1.run())
task1 = asyncio.create_task(run(worker1))
tasks.append(task1)
worker2 = Worker(
@@ -75,7 +74,7 @@ async def check_runner_connection(
global_events=global_events,
)
workers.append(worker2)
task2 = asyncio.create_task(worker2.run())
task2 = asyncio.create_task(run(worker2))
tasks.append(task2)
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
@@ -151,39 +150,41 @@ async def check_runner_connection(
# Check Running status
def test_runner_connection_stress(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, str], Task],
) -> None:
total_runs = 100
successes = 0
# # not now.
for _ in range(total_runs):
# Create a fresh event loop for each iteration
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# def test_runner_connection_stress(
# logger: Logger,
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
# hosts: Callable[[int], list[Host]],
# chat_completion_task: Callable[[InstanceId, str], Task],
# ) -> None:
# total_runs = 100
# successes = 0
try:
result = loop.run_until_complete(check_runner_connection(
logger=logger,
pipeline_shard_meta=pipeline_shard_meta,
hosts=hosts,
chat_completion_task=chat_completion_task,
))
if result:
successes += 1
finally:
# Cancel all running tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# for _ in range(total_runs):
# # Create a fresh event loop for each iteration
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# Run the event loop briefly to allow cancellation to complete
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# try:
# result = loop.run_until_complete(check_runner_connection(
# logger=logger,
# pipeline_shard_meta=pipeline_shard_meta,
# hosts=hosts,
# chat_completion_task=chat_completion_task,
# ))
# if result:
# successes += 1
# finally:
# # Cancel all running tasks
# pending = asyncio.all_tasks(loop)
# for task in pending:
# task.cancel()
# Close the event loop
loop.close()
# # Run the event loop briefly to allow cancellation to complete
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
print(f"Runner connection successes: {successes} / {total_runs}")
# # Close the event loop
# loop.close()
# print(f"Runner connection successes: {successes} / {total_runs}")

View File

@@ -1,4 +1,3 @@
from pathlib import Path
from typing import Callable, TypeVar
from pydantic import BaseModel, TypeAdapter
@@ -28,7 +27,6 @@ def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]):
def test_supervisor_setup_message_serdes(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
tmp_path: Path,
):
setup_message = SetupMessage(
model_shard_meta=pipeline_shard_meta(1, 0),

View File

@@ -10,13 +10,13 @@ from shared.types.events import (
)
from shared.types.events._events import RunnerStatusUpdated
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance, InstanceId
from shared.types.worker.ops import (
RunnerUpOp,
)
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
from worker.tests.constants import RUNNER_1_ID
# To enable this test, run pytest with: ENABLE_SPINUP_TIMEOUT_TEST=true pytest
@@ -26,13 +26,13 @@ from worker.main import Worker
)
@pytest.mark.asyncio
async def test_runner_up_op_timeout(
worker_with_assigned_runner: tuple[Worker, RunnerId, Instance],
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
monkeypatch: pytest.MonkeyPatch
):
worker, runner_id, _ = worker_with_assigned_runner
worker, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
# _execute_runner_up_op should throw a TimeoutError with a short timeout
events: list[Event] = []

View File

@@ -1,6 +1,5 @@
import asyncio
from logging import Logger
from pathlib import Path
from typing import Callable
import pytest
@@ -30,7 +29,6 @@ async def test_supervisor_single_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
@@ -70,7 +68,6 @@ async def test_supervisor_two_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
@@ -133,7 +130,6 @@ async def test_supervisor_early_stopping(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
@@ -189,7 +185,6 @@ async def test_supervisor_handles_terminated_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
logger: Logger,
tmp_path: Path,
):
"""Test that the supervisor handles a terminated runner"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -214,7 +209,6 @@ async def test_supervisor_handles_terminated_runner(
async def test_supervisor_handles_killed_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
tmp_path: Path,
logger: Logger,
):
"""Test that the supervisor handles a killed runner"""

View File

@@ -1,237 +0,0 @@
## Tests for worker state handlers
from pathlib import Path
from typing import Callable
import pytest
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
RunnerDeleted,
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance, InstanceId
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
AssignedRunnerStatus,
FailedRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
RunningRunnerStatus,
)
from worker.main import Worker
@pytest.fixture
def user_message():
"""Override the default message to ask about France's capital"""
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
@pytest.mark.asyncio
async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], tmp_path: Path):
runner_id = RunnerId()
instance_obj: Instance = instance(InstanceId(), worker.node_id, runner_id)
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
events: list[Event] = []
async for event in worker._execute_op(assign_op): # type: ignore[misc]
events.append(event)
# We should have a status update saying 'starting'.
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, AssignedRunnerStatus)
# And the runner should be assigned
assert runner_id in worker.assigned_runners
assert isinstance(worker.assigned_runners[runner_id].status, AssignedRunnerStatus)
@pytest.mark.asyncio
async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
worker, runner_id, _ = worker_with_assigned_runner
unassign_op = UnassignRunnerOp(
runner_id=runner_id
)
events: list[Event] = []
async for event in worker._execute_op(unassign_op): # type: ignore[misc]
events.append(event)
# We should have no assigned runners and no events were emitted
assert len(worker.assigned_runners) == 0
assert len(events) == 1
assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(
worker_with_assigned_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path
):
worker, runner_id, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
events: list[Event] = []
async for event in worker._execute_op(runner_up_op): # type: ignore[misc]
events.append(event)
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, LoadedRunnerStatus)
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.healthy
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task(InstanceId(), TaskId())):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
assert "42" in full_response.lower(), (
f"Expected '42' in response, but got: {full_response}"
)
runner = worker.assigned_runners[runner_id].runner
assert runner is not None
await runner.astop() # Neat cleanup.
@pytest.mark.asyncio
async def test_runner_down_op(worker_with_running_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
worker, runner_id, _ = worker_with_running_runner
runner_down_op = RunnerDownOp(runner_id=runner_id)
events: list[Event] = []
async for event in worker._execute_op(runner_down_op): # type: ignore[misc]
events.append(event)
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, ReadyRunnerStatus)
@pytest.mark.asyncio
async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
worker, runner_id, instance_obj = worker_with_assigned_runner
print(f'{worker.assigned_runners=}')
download_op = DownloadOp(
instance_id=instance_obj.instance_id,
runner_id=runner_id,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
)
events: list[Event] = []
async for event in worker._execute_op(download_op): # type: ignore[misc]
events.append(event)
# Should give download status and then a final download status with DownloadCompleted
print(events)
@pytest.mark.asyncio
async def test_execute_task_op(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path):
worker, runner_id, _ = worker_with_running_runner
execute_task_op = ExecuteTaskOp(
runner_id=runner_id,
task=chat_completion_task(InstanceId(), TaskId())
)
events: list[Event] = []
async for event in worker._execute_op(execute_task_op): # type: ignore[misc]
events.append(event)
assert len(events) > 20
print(f'{events=}')
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus)
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
assert isinstance(events[-2], TaskStateUpdated)
assert events[-2].task_status == TaskStatus.COMPLETE # Task finished successfully.
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(events[-1].runner_status, LoadedRunnerStatus) # It should not have failed.
gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
assert len(text_chunks) == len(events) - 4
output_text = ''.join([x.text for x in text_chunks])
assert '42' in output_text
runner = worker.assigned_runners[runner_id].runner
assert runner is not None
await runner.astop() # Neat cleanup.
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path):
worker, runner_id, _ = worker_with_running_runner
task = chat_completion_task(InstanceId(), TaskId())
messages = task.task_params.messages
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
execute_task_op = ExecuteTaskOp(
runner_id=runner_id,
task=task
)
events: list[Event] = []
async for event in worker._execute_op(execute_task_op): # type: ignore[misc]
events.append(event)
assert len(events) == 5
print(events)
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start.
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
assert isinstance(events[2], TaskStateUpdated)
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
assert isinstance(events[3], TaskFailed)
assert isinstance(events[4], RunnerStatusUpdated)
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.

View File

@@ -1,913 +0,0 @@
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
import pytest
from shared.types.api import ChatCompletionMessage
from shared.types.state import State
from shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
TaskStatus,
TaskType,
)
from shared.types.worker.common import NodeStatus
from shared.types.worker.downloads import DownloadPending
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
AssignedRunnerStatus,
DownloadingRunnerStatus,
FailedRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
RunningRunnerStatus,
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.download_utils import build_model_path
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import AssignedRunner, Worker
from .test_worker_plan_utils import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
InProcessRunner,
PlanTestCase,
make_downloading_status,
make_model_meta,
make_shard_metadata,
)
"""
The idea with these tests is to define declaratively the input and expected output of the worker.plan function.
We initialize a Worker with InProcessRunners. We then construct a State which gets passed to Worker.plan.
We then check what operation is returned by Worker.plan.
"""
def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
# The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation.
(tmp_path / f"model_for_runner_{RUNNER_1_ID}").mkdir(exist_ok=True, parents=True)
model_a_meta = make_model_meta(MODEL_A_ID)
return [
PlanTestCase(
description="no runners -> no-op",
in_process_runners=[],
state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
expected_op=None,
),
# I don't think this should ever happen, as if it's currently downloading then the worker loop will be blocked
# Potentially useful for future compatibility when worker becomes non-blocking
PlanTestCase(
description="runner state assigned, runner is assigned and downloading -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=make_downloading_status(NODE_A),
downloaded=False,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
),
expected_op=None,
),
PlanTestCase(
description="runner state downloading, runner is downloading -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=make_downloading_status(NODE_A),
downloaded=False,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
),
expected_op=None,
),
PlanTestCase(
description="ready runner, model present -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus()},
),
expected_op=None,
),
PlanTestCase(
description="runner assigned and not in state -> AssignRunnerOp",
in_process_runners=[],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: AssignedRunnerStatus()},
),
expected_op=AssignRunnerOp(
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
shard_metadata=PipelineShardMetadata(
device_rank=0,
world_size=1,
model_meta=model_a_meta,
start_layer=0,
end_layer=1,
n_layers=1,
),
hosts=[]
),
),
PlanTestCase(
description="runner assigned but no longer in state -> UnassignRunnerOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=AssignedRunnerStatus(),
downloaded=False,
)
],
state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
PlanTestCase(
description="runner state assigned, runner is assigned, not downloaded -> expect DownloadOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=AssignedRunnerStatus(),
downloaded=False,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: AssignedRunnerStatus()},
),
expected_op=DownloadOp(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
shard_metadata=PipelineShardMetadata(
device_rank=0,
world_size=1,
model_meta=model_a_meta,
start_layer=0,
end_layer=1,
n_layers=1,
),
hosts=[],
),
),
PlanTestCase(
description="ready runner (and state up) -> expect RunnerUpOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus()},
tasks={},
),
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
PlanTestCase(
description="1 ready, 1 downloading (and state up) -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=DownloadingRunnerStatus(
download_progress=DownloadPending(node_id=NODE_A)
),
downloaded=False,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
),
expected_op=None
),
PlanTestCase(
description="2 ready runners (and state up) -> expect RunnerUpOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
),
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID)
),
PlanTestCase(
description="loaded runner (and state down) -> expect RunnerDownOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus()},
tasks={},
),
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
PlanTestCase(
description="failed runner (and state down) -> expect RunnerDownOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=FailedRunnerStatus(),
downloaded=True,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.INACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: FailedRunnerStatus()},
tasks={},
),
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
PlanTestCase(
description="loaded runner, model present, task pending -> expect ExecuteTaskOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
)
],
state=State(
node_status={NODE_A: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus()},
tasks={
TASK_1_ID: ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(
role="user",
content="Hello, world!"
)
]
),
instance_id=INSTANCE_1_ID
)
},
),
expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")]
),
)),
),
PlanTestCase(
# We should only run rank 0 once all other ranks are running.
description="two loaded runners & task, i'm rank 0 -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
),
expected_op=None
),
PlanTestCase(
description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=1,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=0,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=1, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=0, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
),
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
),
task_status=TaskStatus.PENDING,
),
),
),
PlanTestCase(
description="rank 1 loaded, rank 0 ready, i'm rank 0 -> expect ExecuteTaskOp on rank 0",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=RunningRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
),
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
),
task_status=TaskStatus.PENDING,
),
),
),
PlanTestCase(
description="other runner failed -> RunnerDownOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=FailedRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()},
),
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
PlanTestCase(
description="this runner failed (1 node) -> RunnerDownOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=FailedRunnerStatus(),
downloaded=True,
device_rank=0,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1),
},
node_to_runner={NODE_A: RUNNER_1_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: FailedRunnerStatus()},
),
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
PlanTestCase(
description="this runner failed (2 nodes) -> no-op",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=FailedRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=LoadedRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
),
expected_op=None
),
PlanTestCase(
description="this node failed, other node spun down -> RunnerDownOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=FailedRunnerStatus(),
downloaded=True,
device_rank=0,
),
InProcessRunner(
runner_id=RUNNER_2_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=ReadyRunnerStatus(),
downloaded=True,
device_rank=1,
),
],
state=State(
node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
instances={
INSTANCE_1_ID: Instance(
instance_type=InstanceStatus.ACTIVE,
instance_id=INSTANCE_1_ID,
shard_assignments=ShardAssignments(
model_id=MODEL_A_ID,
runner_to_shard={
RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
),
hosts=[]
)
},
runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
),
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
),
]
# ---------------------------------------------------------------------------
# Parametrised test
# ---------------------------------------------------------------------------
# Pre-compute readable identifiers for each case to avoid lambda typing issues.
@pytest.mark.parametrize(
"case",
# We use a factory to delay test case generation until tmp_path is available.
[pytest.param(c, id=c.id()) for c in _get_test_cases(Path(tempfile.TemporaryDirectory().name))],
)
def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Exercise Worker.plan across declarative scenarios."""
print(f"----- case: {case.description}")
# Regenerate test cases with the actual tmp_path fixture
test_cases = {c.description: c for c in _get_test_cases(tmp_path)}
case = test_cases[case.description]
node_id = NODE_A
logger = logging.getLogger("test_worker_plan")
shard_downloader = NoopShardDownloader()
worker = Worker(node_id=node_id, shard_downloader=shard_downloader, worker_events=None, global_events=None, logger=logger)
path_downloaded_map: dict[str, bool] = {}
runner_config: InProcessRunner
for runner_config in case.in_process_runners:
model_path = tmp_path / f"model_for_runner_{runner_config.runner_id}"
model_path.mkdir(exist_ok=True, parents=True)
if len(case.state.instances) == 1:
instance_id = next(iter(case.state.instances))
shard_assignments = case.state.instances[instance_id].shard_assignments
shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
# Only add this runner if it belongs to our node
runner_node = None
for node, runner in shard_assignments.node_to_runner.items():
if runner == runner_config.runner_id:
runner_node = node
break
if runner_node != node_id:
# This runner belongs to a different node, skip it
continue
elif len(case.state.instances) == 0:
shard_metadata = PipelineShardMetadata(
device_rank=runner_config.device_rank,
world_size=1,
model_meta=make_model_meta(runner_config.model_id),
start_layer=0,
end_layer=1,
n_layers=1,
)
else:
raise Exception('test_worker_plan not currently designed to have more than 1 instance.')
assigned_runner = AssignedRunner(
runner_id=runner_config.runner_id,
instance_id=runner_config.instance_id,
shard_metadata=shard_metadata,
hosts=[],
status=runner_config.status,
runner=None,
is_downloaded=runner_config.downloaded
)
worker.assigned_runners[runner_config.runner_id] = assigned_runner
path_downloaded_map[str(build_model_path(shard_metadata.model_meta.model_id))] = runner_config.downloaded
op = worker.plan(case.state)
assert op == case.expected_op

View File

@@ -1,195 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Final, List, Optional
from shared.models.model_cards import MODEL_CARDS, ModelCard
from shared.types.common import CommandId, NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.tasks import TaskId
from shared.types.worker.common import InstanceId, NodeStatus, RunnerId
from shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import RunnerOp
from shared.types.worker.runners import (
AssignedRunnerStatus,
DownloadingRunnerStatus,
RunnerStatus,
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId("cccccccc-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
INSTANCE_1_ID: Final[InstanceId] = InstanceId()
RUNNER_2_ID: Final[RunnerId] = RunnerId("dddddddd-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
INSTANCE_2_ID: Final[InstanceId] = InstanceId()
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId()
COMMAND_1_ID: Final[CommandId] = CommandId()
@dataclass(slots=True, frozen=True)
class InProcessRunner:
"""Minimal description of a runner's in-process state."""
runner_id: RunnerId
instance_id: InstanceId
model_id: ModelId
status: RunnerStatus
downloaded: bool
device_rank: int = 0
@dataclass(slots=True, frozen=True)
class PlanTestCase:
"""Table-driven description of an entire planning scenario."""
description: str
state: State
in_process_runners: List[InProcessRunner]
expected_op: Optional[RunnerOp]
def id(self) -> str: # noqa: D401
return self.description.replace(" ", "_")
def make_shard_metadata(device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID) -> PipelineShardMetadata:
"""Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
total_layers = world_size # For simplicity in tests, total_layers = world_size
if world_size == 1:
start_layer = 0
end_layer = 1
n_layers = 1
else:
# For multi-device setup, each device gets one layer
start_layer = device_rank
end_layer = device_rank + 1
n_layers = total_layers
return PipelineShardMetadata(
device_rank=device_rank,
world_size=world_size,
model_meta=make_model_meta(model_id),
start_layer=start_layer,
end_layer=end_layer,
n_layers=n_layers,
)
def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
"""Factory for a *Downloading* status with placeholder progress."""
return DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(total_bytes=1, downloaded_bytes=0),
)
)
def make_model_meta(
model_id: str
) -> ModelMetadata:
model_card: ModelCard
for card in MODEL_CARDS.values():
if card.model_id == model_id:
model_card = card
return ModelMetadata(
model_id=model_id,
pretty_name=model_card.model_id,
storage_size_kilobytes=10**6,
n_layers=16,
)
raise Exception(f'Unknown model_id passed: {model_id}')
## Alternatively, if we are ok for this method to be async:
# await _get_model_meta(model_id)
def create_worker_state(
*,
node_id: NodeId,
runner_configs: list[tuple[RunnerId, InstanceId, ModelId]],
tmp_path: Path,
) -> State:
"""Create a test `State` based on a list of runner configurations."""
instances: dict[InstanceId, Instance] = {}
for runner_id, instance_id, model_id in runner_configs:
model_path = tmp_path / f"model_for_runner_{runner_id}"
model_path.mkdir(exist_ok=True, parents=True)
shard_metadata = PipelineShardMetadata(
device_rank=0,
world_size=1,
model_meta=make_model_meta(model_id),
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={runner_id: shard_metadata},
node_to_runner={node_id: runner_id},
)
instance = Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=[],
)
instances[instance_id] = instance
return State(
node_status={node_id: NodeStatus.Idle},
instances=instances,
runners={runner_id: AssignedRunnerStatus() for runner_id, _, _ in runner_configs},
tasks={},
)
def make_instance(
instance_id: InstanceId,
model_id: ModelId,
tmp_path: Path,
runner_specs: list[tuple[RunnerId, NodeId, int]],
) -> Instance:
"""Creates an instance with one or more runners."""
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
world_size = len(runner_specs)
for runner_id, node_id, device_rank in runner_specs:
model_path = tmp_path / f"model_for_runner_{runner_id}"
model_path.mkdir(exist_ok=True, parents=True)
shard_metadata = PipelineShardMetadata(
device_rank=device_rank,
world_size=world_size,
model_meta=make_model_meta(model_id),
start_layer=0,
end_layer=1,
n_layers=1,
)
runner_to_shard[runner_id] = shard_metadata
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
return Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=[],
)
### For worker plan tests

415
worker/worker.py Normal file
View File

@@ -0,0 +1,415 @@
import asyncio
import logging
import time
from asyncio import Queue
from functools import partial
from time import process_time
from typing import AsyncGenerator, Optional
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
InstanceDeleted,
RunnerDeleted,
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.state import State
from shared.types.tasks import TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgressData,
)
from shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerFailedOp,
RunnerOp,
RunnerOpType,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
DownloadingRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunningRunnerStatus,
)
from shared.types.worker.shards import ShardMetadata
from worker.common import AssignedRunner
from worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from worker.runner.runner_supervisor import RunnerSupervisor
class Worker:
def __init__(
self,
node_id: NodeId,
logger: logging.Logger,
shard_downloader: ShardDownloader,
worker_events: AsyncSQLiteEventStorage | None,
global_events: AsyncSQLiteEventStorage | None,
):
self.node_id: NodeId = node_id
self.state: State = State()
self.shard_downloader: ShardDownloader = shard_downloader
self.worker_events: AsyncSQLiteEventStorage | None = worker_events # worker_events is None in some tests.
self.global_events: AsyncSQLiteEventStorage | None = global_events
self.logger: logging.Logger = logger
self.assigned_runners: dict[RunnerId, AssignedRunner] = {}
self._task: asyncio.Task[None] | None = None
## Op Executors
async def _execute_assign_op(
self, op: AssignRunnerOp
) -> AsyncGenerator[Event, None]:
'''
A runner has been assigned. We also need to ensure its model shard is downloaded.
This op registers the runner and moves it from Downloading -> Inactive (ready to spin up).
'''
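# Status progression below: Downloading (pending/ongoing) -> Downloading (completed) -> Inactive on
# success, or Downloading (failed) if the progress loop times out.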
self.assigned_runners[op.runner_id] = AssignedRunner(
runner_id=op.runner_id,
instance_id=op.instance_id,
shard_metadata=op.shard_metadata,
hosts=op.hosts,
status=DownloadingRunnerStatus(
download_progress=DownloadPending(
node_id=self.node_id
)
),
runner=None,
)
assigned_runner = self.assigned_runners[op.runner_id]
initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)
if initial_progress.status == "complete":
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadCompleted(
node_id=self.node_id
)
)
yield assigned_runner.status_update_event()
assigned_runner.status = InactiveRunnerStatus()
yield assigned_runner.status_update_event()
return
else:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=initial_progress.total_bytes,
downloaded_bytes=initial_progress.downloaded_bytes
)
)
)
yield assigned_runner.status_update_event()
# Download it!
# TODO: we probably want download progress as part of a callback that gets passed to the downloader.
download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()
def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
download_progress_queue.put_nowait(progress)
self.shard_downloader.on_progress(download_progress_callback)
asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))
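# The download runs as a background task; the loop below drains progress updates from the queue and
# re-emits them as runner status events, throttled to roughly one update per second.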
# TODO: Make the timeout dynamic, and time out when no progress update has been received.
timeout_secs = 10 * 60
start_time = process_time()
last_yield_progress = start_time
while process_time() - start_time < timeout_secs:
progress: RepoDownloadProgress = await download_progress_queue.get()
if progress.status == "complete":
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadCompleted(
node_id=self.node_id,
)
)
yield assigned_runner.status_update_event()
assigned_runner.status = InactiveRunnerStatus()
yield assigned_runner.status_update_event()
break
elif progress.status == "in_progress":
if process_time() - last_yield_progress > 1:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=DownloadProgressData(
total_bytes=progress.total_bytes,
downloaded_bytes=progress.downloaded_bytes,
)
)
)
yield assigned_runner.status_update_event()
last_yield_progress = process_time()
else:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadFailed(
node_id=self.node_id,
error_message=f"Timeout downloading model: {op.shard_metadata.model_meta.model_id}"
)
)
yield assigned_runner.status_update_event()
async def _execute_unassign_op(
self, op: UnassignRunnerOp
) -> AsyncGenerator[Event, None]:
if op.runner_id not in self.assigned_runners:
return
# We can try to do a graceful shutdown of the runner.
runner: RunnerSupervisor | None = self.assigned_runners[op.runner_id].runner
if runner is not None:
await runner.astop()
# This is all we really need:
del self.assigned_runners[op.runner_id]
yield RunnerDeleted(runner_id=op.runner_id)
return
yield
async def _execute_runner_up_op(
self, op: RunnerUpOp, initialize_timeout: Optional[float] = None
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
# TODO: This should be dynamic, based on the size of the model.
if not initialize_timeout:
gigabytes_per_second = 10
kilobytes_per_second = gigabytes_per_second * 1024 * 1024
shard = assigned_runner.shard_metadata
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
initialize_timeout = weights_size_kb / kilobytes_per_second + 120.0 # Add a constant 120.0 to ensure connection can be made as well
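# Rough worked example: a ~4 GB shard (~4 * 1024 * 1024 kB) at the assumed 10 GB/s loads in about
# 0.4 s, so the effective timeout is dominated by the constant 120 s connection margin.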
self.logger.info(f"initialize_timeout: {initialize_timeout}")
try:
assigned_runner.runner = await asyncio.wait_for(
RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
logger=self.logger,
),
timeout=initialize_timeout,
)
except TimeoutError as e:
import traceback
tb = traceback.format_exc()
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
yield event
return
if assigned_runner.runner.healthy:
assigned_runner.status = LoadedRunnerStatus()
else:
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_runner_down_op(
self, op: RunnerDownOp
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
if isinstance(assigned_runner.runner, RunnerSupervisor):
await assigned_runner.runner.astop()
assigned_runner.runner = None
assigned_runner.status = InactiveRunnerStatus()
yield assigned_runner.status_update_event()
return
async def _execute_runner_failed_op(
self, op: RunnerFailedOp
) -> AsyncGenerator[Event, None]:
'''
We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down.
'''
assigned_runner = self.assigned_runners[op.runner_id]
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_task_op(
self, op: ExecuteTaskOp
) -> AsyncGenerator[Event, None]:
'''
This is the entry point for a chat completion starting.
Although there is only one execute function, it is invoked differently for runner 0 and for runners [1, 2, 3, ...].
Runners [1, 2, 3, ...] will run this method when a task is in 'pending' state.
Runner 0 will run this method when a task is in 'running' state.
TODO: How do we handle the logic of ensuring that n-1 nodes have started their execution before allowing the 0'th runner to start?
This is still a little unclear to me.
'''
assigned_runner = self.assigned_runners[op.runner_id]
async def inner_execute(queue: asyncio.Queue[Event]) -> None:
async def running_callback(queue: asyncio.Queue[Event]) -> None:
# Called when the MLX process has been kicked off
assigned_runner.status = RunningRunnerStatus()
await queue.put(assigned_runner.status_update_event())
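# Only the rank-0 runner publishes task-level state transitions.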
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.RUNNING,
))
try:
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(ChunkGenerated(
# todo: at some point we will no longer have a bijection between task_id and row_id.
# So we probably want to store a mapping between these two in our Worker object.
command_id=chunk.command_id,
chunk=chunk
))
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.COMPLETE,
))
# After a successful inference:
assigned_runner.status = LoadedRunnerStatus()
await queue.put(assigned_runner.status_update_event())
except Exception as e:
# An exception occurs in the runner supervisor
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
await queue.put(event)
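# Run the inference in a background task and drain its events from the queue here, so we can enforce per-event timeouts.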
queue: asyncio.Queue[Event] = asyncio.Queue()
task = asyncio.create_task(inner_execute(queue))
# TODO: The initial (prefill) timeout could be made dynamic, e.g. based on the model size:
# model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
try:
# Yield items from the queue as they arrive.
timeout = 3.
while True:
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
yield item
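# The first event may take longer (prefill); subsequent chunks should arrive quickly, so tighten the timeout.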
timeout = 2.
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
):
if isinstance(item.runner_status, LoadedRunnerStatus):
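# A successful run clears the runner's recorded failure history.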
assigned_runner.failures = []
break
except TimeoutError as e:
# The runner supervisor didn't respond in time, so we put both the runner and the task into a failed state.
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
yield event
finally:
# Ensure the task is cleaned up
try:
await asyncio.wait_for(task, timeout=5)
except asyncio.TimeoutError:
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
## Operation Planner
async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
## It would be great if we could get rid of this async for ... yield pattern.
match op.op_type:
case RunnerOpType.ASSIGN_RUNNER:
event_generator = self._execute_assign_op(op)
case RunnerOpType.UNASSIGN_RUNNER:
event_generator = self._execute_unassign_op(op)
case RunnerOpType.RUNNER_UP:
event_generator = self._execute_runner_up_op(op)
case RunnerOpType.RUNNER_DOWN:
event_generator = self._execute_runner_down_op(op)
case RunnerOpType.RUNNER_FAILED:
event_generator = self._execute_runner_failed_op(op)
case RunnerOpType.CHAT_COMPLETION:
event_generator = self._execute_task_op(op)
async for event in event_generator:
yield event
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event, None]:
if runner_id in self.assigned_runners:
assigned_runner = self.assigned_runners[runner_id]
assigned_runner.runner = None
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
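# Record the failure with a timestamp so repeated failures can be counted below.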
assigned_runner.failures.append(
(
time.time(),
e
)
)
if len(assigned_runner.failures) >= 3:
# Too many failures for this runner: emit InstanceDeleted so the whole instance is torn down.
yield InstanceDeleted(
instance_id=assigned_runner.instance_id
)
yield assigned_runner.status_update_event()
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event, None]:
if runner_id in self.assigned_runners:
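# Mark the task as failed, emit the error, then fail the runner that was executing it.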
yield TaskStateUpdated(
task_id=task_id,
task_status=TaskStatus.FAILED,
)
yield TaskFailed(
task_id=task_id,
error_type=str(type(e)),
error_message=str(e)
)
async for event in self._fail_runner(e, runner_id):
yield event
async def event_publisher(self, event: Event) -> None:
assert self.worker_events is not None
await self.worker_events.append_events([event], self.node_id)
self.logger.info(f"published event: {event}")