Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-19 23:36:30 -05:00

Compare commits: feat/bug-r...move-messa (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | cc3b8392e6 | |
| | 4c4c6ce99f | |
| | 42e1e7322b | |
| | aa3f106fb9 | |
| | 2e29605194 | |
| | cacb456cb2 | |
Cargo.lock (generated): 14 lines changed
@@ -890,7 +890,7 @@ dependencies = [
  "delegate",
  "env_logger",
  "extend",
- "futures",
+ "futures-lite",
  "libp2p",
  "log",
  "networking",
@@ -914,6 +914,12 @@ dependencies = [
  "syn 2.0.111",
 ]
 
+[[package]]
+name = "fastrand"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
+
 [[package]]
 name = "ff"
 version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
 dependencies = [
+ "fastrand",
  "futures-core",
+ "futures-io",
+ "parking",
  "pin-project-lite",
 ]
 
@@ -2753,11 +2762,12 @@ dependencies = [
  "delegate",
  "either",
  "extend",
- "futures",
+ "futures-lite",
  "futures-timer",
  "keccak-const",
  "libp2p",
  "log",
+ "pin-project",
  "tokio",
  "tracing-subscriber",
  "util",
@@ -29,14 +29,13 @@ util = { path = "rust/util" }
 # Macro dependecies
 extend = "1.2"
 delegate = "0.13"
-pin-project = "1"
 
 # Utility dependencies
 keccak-const = "0.2"
 
 # Async dependencies
 tokio = "1.46"
-futures = "0.3"
+futures-lite = "2.6.1"
 futures-timer = "3.0"
 
 # Data structures
@@ -20,6 +20,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
 
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
 
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+
+    print("Planning phase: checking downloads...", file=log)
+    run_planning_phase(
+        exo,
+        full_model_id,
+        preview,
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
+
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])
@@ -35,6 +35,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
     if args.dry_run:
         return 0
 
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+
+    logger.info("Planning phase: checking downloads...")
+    run_planning_phase(
+        client,
+        full_model_id,
+        selected[0],
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
+
     all_rows: list[dict[str, Any]] = []
 
     for preview in selected:
bench/harness.py: 150 lines changed
@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
     return selected
 
 
+def run_planning_phase(
+    client: ExoClient,
+    full_model_id: str,
+    preview: dict[str, Any],
+    danger_delete: bool,
+    timeout: float,
+    settle_deadline: float | None,
+) -> None:
+    """Check disk space and ensure model is downloaded before benchmarking."""
+    # Get model size from /models
+    models = client.request_json("GET", "/models") or {}
+    model_bytes = 0
+    for m in models.get("data", []):
+        if m.get("hugging_face_id") == full_model_id:
+            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
+            break
+
+    if not model_bytes:
+        logger.warning(
+            f"Could not determine size for {full_model_id}, skipping disk check"
+        )
+        return
+
+    # Get nodes from preview
+    inner = unwrap_instance(preview["instance"])
+    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
+    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
+
+    state = client.request_json("GET", "/state")
+    downloads = state.get("downloads", {})
+    node_disk = state.get("nodeDisk", {})
+
+    for node_id in node_ids:
+        node_downloads = downloads.get(node_id, [])
+
+        # Check if model already downloaded on this node
+        already_downloaded = any(
+            "DownloadCompleted" in p
+            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                "modelId"
+            ]
+            == full_model_id
+            for p in node_downloads
+        )
+        if already_downloaded:
+            continue
+
+        # Wait for disk info if settle_deadline is set
+        disk_info = node_disk.get(node_id, {})
+        backoff = _SETTLE_INITIAL_BACKOFF_S
+        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
+            remaining = settle_deadline - time.monotonic()
+            logger.info(
+                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
+            )
+            time.sleep(min(backoff, remaining))
+            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
+            state = client.request_json("GET", "/state")
+            node_disk = state.get("nodeDisk", {})
+            disk_info = node_disk.get(node_id, {})
+
+        if not disk_info:
+            logger.warning(f"No disk info for {node_id}, skipping space check")
+            continue
+
+        avail = disk_info.get("available", {}).get("inBytes", 0)
+        if avail >= model_bytes:
+            continue
+
+        if not danger_delete:
+            raise RuntimeError(
+                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
+                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
+            )
+
+        # Delete from smallest to largest
+        completed = [
+            (
+                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ],
+                p["DownloadCompleted"]["totalBytes"]["inBytes"],
+            )
+            for p in node_downloads
+            if "DownloadCompleted" in p
+        ]
+        for del_model, size in sorted(completed, key=lambda x: x[1]):
+            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
+            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
+            avail += size
+            if avail >= model_bytes:
+                break
+
+        if avail < model_bytes:
+            raise RuntimeError(f"Could not free enough space on {node_id}")
+
+    # Start downloads (idempotent)
+    for node_id in node_ids:
+        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
+        shard = runner_to_shard[runner_id]
+        client.request_json(
+            "POST",
+            "/download/start",
+            body={
+                "targetNodeId": node_id,
+                "shardMetadata": shard,
+            },
+        )
+        logger.info(f"Started download on {node_id}")
+
+    # Wait for downloads
+    start = time.time()
+    while time.time() - start < timeout:
+        state = client.request_json("GET", "/state")
+        downloads = state.get("downloads", {})
+        all_done = True
+        for node_id in node_ids:
+            done = any(
+                "DownloadCompleted" in p
+                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
+                    "modelCard"
+                ]["modelId"]
+                == full_model_id
+                for p in downloads.get(node_id, [])
+            )
+            failed = [
+                p["DownloadFailed"]["errorMessage"]
+                for p in downloads.get(node_id, [])
+                if "DownloadFailed" in p
+                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ]
+                == full_model_id
+            ]
+            if failed:
+                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
+            if not done:
+                all_done = False
+        if all_done:
+            return
+        time.sleep(1)
+
+    raise TimeoutError("Downloads did not complete in time")
+
+
 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
+    ap.add_argument(
+        "--danger-delete-downloads",
+        action="store_true",
+        help="Delete existing models from smallest to largest to make room for benchmark model.",
+    )
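A note on the eviction policy above: when `--danger-delete-downloads` is set, `run_planning_phase` frees space greedily, deleting completed downloads from smallest to largest until the benchmark model fits, and stops as soon as the projected free space covers it. A standalone Python illustration of that policy (the sizes and model names here are made up, not taken from any real node):

# Standalone illustration of the smallest-first eviction used in
# run_planning_phase; nothing here touches a real cluster.
def plan_evictions(completed: list[tuple[str, int]], avail: int, need: int) -> list[str]:
    """Return the model IDs that would be deleted, smallest first."""
    evicted = []
    for model_id, size in sorted(completed, key=lambda x: x[1]):
        if avail >= need:
            break
        evicted.append(model_id)
        avail += size
    return evicted

gb = 1024**3
# Need 30 GB with 10 GB free: the 8 GB and 20 GB models go, the 70 GB one stays.
print(plan_evictions(
    [("model-70b", 70 * gb), ("model-8b", 8 * gb), ("model-20b", 20 * gb)],
    avail=10 * gb,
    need=30 * gb,
))  # -> ['model-8b', 'model-20b']

Smallest-first keeps the largest cached models on disk as long as possible, at the cost of deleting more individual entries.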
@@ -1,188 +0,0 @@
-<script lang="ts">
-  import { fade, fly } from "svelte/transition";
-  import { cubicOut } from "svelte/easing";
-
-  interface Props {
-    isOpen: boolean;
-    onClose: () => void;
-  }
-
-  let { isOpen, onClose }: Props = $props();
-
-  let bugReportId = $state<string | null>(null);
-  let githubIssueUrl = $state<string | null>(null);
-  let isLoading = $state(false);
-  let error = $state<string | null>(null);
-
-  async function generateBugReport() {
-    isLoading = true;
-    error = null;
-    try {
-      const response = await fetch("/bug-report", { method: "POST" });
-      if (!response.ok) {
-        error = "Failed to generate bug report. Please try again.";
-        return;
-      }
-      const data = await response.json();
-      bugReportId = data.bugReportId;
-      githubIssueUrl = data.githubIssueUrl;
-    } catch {
-      error = "Failed to connect to the server. Please try again.";
-    } finally {
-      isLoading = false;
-    }
-  }
-
-  function handleClose() {
-    bugReportId = null;
-    githubIssueUrl = null;
-    error = null;
-    isLoading = false;
-    onClose();
-  }
-
-  // Generate bug report when modal opens
-  $effect(() => {
-    if (isOpen && !bugReportId && !isLoading) {
-      generateBugReport();
-    }
-  });
-</script>
-
-{#if isOpen}
-  <!-- Backdrop -->
-  <div
-    class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
-    transition:fade={{ duration: 200 }}
-    onclick={handleClose}
-    role="presentation"
-  ></div>
-
-  <!-- Modal -->
-  <div
-    class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,480px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
-    transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
-    role="dialog"
-    aria-modal="true"
-    aria-label="Bug Report"
-  >
-    <!-- Header -->
-    <div
-      class="flex items-center justify-between px-5 py-4 border-b border-exo-medium-gray/30"
-    >
-      <div class="flex items-center gap-2">
-        <svg
-          class="w-5 h-5 text-exo-yellow"
-          fill="none"
-          viewBox="0 0 24 24"
-          stroke="currentColor"
-          stroke-width="2"
-        >
-          <path
-            stroke-linecap="round"
-            stroke-linejoin="round"
-            d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
-          />
-        </svg>
-        <h2 class="text-sm font-mono text-exo-yellow tracking-wider uppercase">
-          Report a Bug
-        </h2>
-      </div>
-      <button
-        onclick={handleClose}
-        class="text-exo-light-gray hover:text-white transition-colors cursor-pointer"
-        aria-label="Close"
-      >
-        <svg
-          class="w-5 h-5"
-          fill="none"
-          viewBox="0 0 24 24"
-          stroke="currentColor"
-          stroke-width="2"
-        >
-          <path
-            stroke-linecap="round"
-            stroke-linejoin="round"
-            d="M6 18L18 6M6 6l12 12"
-          />
-        </svg>
-      </button>
-    </div>
-
-    <!-- Body -->
-    <div class="px-5 py-5 space-y-4">
-      {#if isLoading}
-        <div class="flex items-center justify-center py-6">
-          <div
-            class="w-5 h-5 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"
-          ></div>
-          <span class="ml-3 text-sm text-exo-light-gray font-mono"
-            >Generating bug report...</span
-          >
-        </div>
-      {:else if error}
-        <div
-          class="text-sm text-red-400 font-mono bg-red-400/10 border border-red-400/20 rounded px-4 py-3"
-        >
-          {error}
-        </div>
-        <button
-          onclick={generateBugReport}
-          class="w-full px-4 py-2.5 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded text-sm font-mono text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer"
-        >
-          Try Again
-        </button>
-      {:else if bugReportId && githubIssueUrl}
-        <p class="text-sm text-exo-light-gray leading-relaxed">
-          Would you like to create a GitHub issue? This would help us track and
-          fix the issue for you.
-        </p>
-
-        <!-- Bug Report ID -->
-        <div
-          class="bg-exo-black/50 border border-exo-medium-gray/30 rounded px-4 py-3"
-        >
-          <div
-            class="text-[11px] text-exo-light-gray/60 font-mono tracking-wider uppercase mb-1"
-          >
-            Bug Report ID
-          </div>
-          <div class="text-sm text-exo-yellow font-mono tracking-wide">
-            {bugReportId}
-          </div>
-          <div class="text-[11px] text-exo-light-gray/50 font-mono mt-1">
-            Include this ID when communicating with the team.
-          </div>
-        </div>
-
-        <p class="text-xs text-exo-light-gray/60 leading-relaxed">
-          No diagnostic data is attached. The issue template contains
-          placeholder fields for you to fill in.
-        </p>
-
-        <!-- Actions -->
-        <div class="flex gap-3 pt-1">
-          <a
-            href={githubIssueUrl}
-            target="_blank"
-            rel="noopener noreferrer"
-            class="flex-1 flex items-center justify-center gap-2 px-4 py-2.5 bg-exo-yellow/10 border border-exo-yellow/40 rounded text-sm font-mono text-exo-yellow hover:bg-exo-yellow/20 hover:border-exo-yellow/60 transition-colors"
-          >
-            <svg class="w-4 h-4" viewBox="0 0 16 16" fill="currentColor">
-              <path
-                d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"
-              />
-            </svg>
-            Create GitHub Issue
-          </a>
-          <button
-            onclick={handleClose}
-            class="px-4 py-2.5 border border-exo-medium-gray/40 rounded text-sm font-mono text-exo-light-gray hover:border-exo-medium-gray/60 transition-colors cursor-pointer"
-          >
-            Close
-          </button>
-        </div>
-      {/if}
-    </div>
-  </div>
-{/if}
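For reference, the deleted modal drove a single endpoint: it POSTed to /bug-report and read `bugReportId` and `githubIssueUrl` from the JSON response. A Python rendering of the same request, purely illustrative; the host and port are assumptions, and the endpoint itself may be removed elsewhere in this branch:

# Illustrative only: mirrors the fetch("/bug-report", { method: "POST" })
# call from the deleted component. Host and port are assumptions.
import json
import urllib.request

req = urllib.request.Request("http://localhost:52415/bug-report", method="POST")
with urllib.request.urlopen(req) as resp:
    data = json.load(resp)

print(data["bugReportId"])     # ID to quote when contacting the team
print(data["githubIssueUrl"])  # prefilled GitHub issue-template link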
@@ -74,7 +74,6 @@
       perSystem =
         { config, self', inputs', pkgs, lib, system, ... }:
         let
-          fenixToolchain = inputs'.fenix.packages.complete;
           # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
         in
@@ -1,2 +0,0 @@
-# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
-#allowed-duplicate-crates = ["hashbrown"]
@@ -27,7 +27,7 @@ networking = { workspace = true }
 # interop
 pyo3 = { version = "0.27.2", features = [
     # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
-    "nightly", # enables better-supported GIL integration
+    # "nightly", # enables better-supported GIL integration
     "experimental-async", # async support in #[pyfunction] & #[pymethods]
     #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
     #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
 # macro dependencies
 extend = { workspace = true }
 delegate = { workspace = true }
-pin-project = { workspace = true }
 
 # async runtime
 tokio = { workspace = true, features = ["full", "tracing"] }
-futures = { workspace = true }
+futures-lite = { workspace = true }
 
 # utility dependencies
 util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"
 
 # Networking
 libp2p = { workspace = true, features = ["full"] }
+pin-project = "1.1.10"
@@ -2,7 +2,6 @@
 # ruff: noqa: E501, F401
 
 import builtins
-import enum
 import typing
 
 @typing.final
@@ -11,138 +10,33 @@ class AllQueuesFullError(builtins.Exception):
     def __repr__(self) -> builtins.str: ...
     def __str__(self) -> builtins.str: ...
 
-@typing.final
-class ConnectionUpdate:
-    @property
-    def update_type(self) -> ConnectionUpdateType:
-        r"""
-        Whether this is a connection or disconnection event
-        """
-    @property
-    def peer_id(self) -> PeerId:
-        r"""
-        Identity of the peer that we have connected to or disconnected from.
-        """
-    @property
-    def remote_ipv4(self) -> builtins.str:
-        r"""
-        Remote connection's IPv4 address.
-        """
-    @property
-    def remote_tcp_port(self) -> builtins.int:
-        r"""
-        Remote connection's TCP port.
-        """
-
 @typing.final
 class Keypair:
     r"""
     Identity keypair of a node.
     """
     @staticmethod
-    def generate_ed25519() -> Keypair:
+    def generate() -> Keypair:
         r"""
         Generate a new Ed25519 keypair.
         """
     @staticmethod
-    def generate_ecdsa() -> Keypair:
-        r"""
-        Generate a new ECDSA keypair.
-        """
-    @staticmethod
-    def generate_secp256k1() -> Keypair:
-        r"""
-        Generate a new Secp256k1 keypair.
-        """
-    @staticmethod
-    def from_protobuf_encoding(bytes: bytes) -> Keypair:
-        r"""
-        Decode a private key from a protobuf structure and parse it as a `Keypair`.
-        """
-    @staticmethod
-    def rsa_from_pkcs8(bytes: bytes) -> Keypair:
-        r"""
-        Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
-        format (i.e. unencrypted) as defined in [RFC5208].
-
-        [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
-        """
-    @staticmethod
-    def secp256k1_from_der(bytes: bytes) -> Keypair:
-        r"""
-        Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
-        structure as defined in [RFC5915].
-
-        [RFC5915]: https://tools.ietf.org/html/rfc5915
-        """
-    @staticmethod
-    def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
-    def to_protobuf_encoding(self) -> bytes:
-        r"""
-        Encode a private key as protobuf structure.
-        """
-    def to_peer_id(self) -> PeerId:
-        r"""
-        Convert the `Keypair` into the corresponding `PeerId`.
-        """
-
-@typing.final
-class Multiaddr:
-    r"""
-    Representation of a Multiaddr.
-    """
-    @staticmethod
-    def empty() -> Multiaddr:
-        r"""
-        Create a new, empty multiaddress.
-        """
-    @staticmethod
-    def with_capacity(n: builtins.int) -> Multiaddr:
-        r"""
-        Create a new, empty multiaddress with the given capacity.
-        """
-    @staticmethod
-    def from_bytes(bytes: bytes) -> Multiaddr:
-        r"""
-        Parse a `Multiaddr` value from its byte slice representation.
-        """
-    @staticmethod
-    def from_string(string: builtins.str) -> Multiaddr:
-        r"""
-        Parse a `Multiaddr` value from its string representation.
-        """
-    def len(self) -> builtins.int:
-        r"""
-        Return the length in bytes of this multiaddress.
-        """
-    def is_empty(self) -> builtins.bool:
-        r"""
-        Returns true if the length of this multiaddress is 0.
-        """
-    def to_bytes(self) -> bytes:
-        r"""
-        Return a copy of this [`Multiaddr`]'s byte representation.
-        """
-    def to_string(self) -> builtins.str:
-        r"""
-        Convert a Multiaddr to a string.
+    def from_bytes(bytes: bytes) -> Keypair:
+        r"""
+        Construct an Ed25519 keypair from secret key bytes
+        """
+    def to_bytes(self) -> bytes:
+        r"""
+        Get the secret key bytes underlying the keypair
+        """
+    def to_node_id(self) -> builtins.str:
+        r"""
+        Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
         """
 
 @typing.final
 class NetworkingHandle:
     def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
-    async def connection_update_recv(self) -> ConnectionUpdate:
-        r"""
-        Receives the next `ConnectionUpdate` from networking.
-        """
-    async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
-        r"""
-        Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
-
-        For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
-        For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
-        will sleep until a `ConnectionUpdate`s is sent.
-        """
     async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
         r"""
         Subscribe to a `GossipSub` topic.
@@ -161,18 +55,7 @@ class NetworkingHandle:
 
         If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
         """
-    async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
-        r"""
-        Receives the next message from the `GossipSub` network.
-        """
-    async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
-        r"""
-        Receives at most `limit` messages from the `GossipSub` network and returns them.
-
-        For `limit = 0`, an empty collection of messages will be returned immediately.
-        For `limit > 0`, if there are no messages in the channel's queue this method
-        will sleep until a message is sent.
-        """
+    async def recv(self) -> PyFromSwarm: ...
 
 @typing.final
 class NoPeersSubscribedToTopicError(builtins.Exception):
@@ -180,42 +63,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
     def __repr__(self) -> builtins.str: ...
     def __str__(self) -> builtins.str: ...
 
-@typing.final
-class PeerId:
-    r"""
-    Identifier of a peer of the network.
-
-    The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
-    as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
-    """
-    @staticmethod
-    def random() -> PeerId:
-        r"""
-        Generates a random peer ID from a cryptographically secure PRNG.
-
-        This is useful for randomly walking on a DHT, or for testing purposes.
-        """
-    @staticmethod
-    def from_bytes(bytes: bytes) -> PeerId:
-        r"""
-        Parses a `PeerId` from bytes.
-        """
-    def to_bytes(self) -> bytes:
-        r"""
-        Returns a raw bytes representation of this `PeerId`.
-        """
-    def to_base58(self) -> builtins.str:
-        r"""
-        Returns a base-58 encoded string of this `PeerId`.
-        """
-    def __repr__(self) -> builtins.str: ...
-    def __str__(self) -> builtins.str: ...
-
-@typing.final
-class ConnectionUpdateType(enum.Enum):
-    r"""
-    Connection or disconnection event discriminant type.
-    """
-    Connected = ...
-    Disconnected = ...
+class PyFromSwarm:
+    @typing.final
+    class Connection(PyFromSwarm):
+        __match_args__ = ("peer_id", "connected",)
+        @property
+        def peer_id(self) -> builtins.str: ...
+        @property
+        def connected(self) -> builtins.bool: ...
+        def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
+
+    @typing.final
+    class Message(PyFromSwarm):
+        __match_args__ = ("origin", "topic", "data",)
+        @property
+        def origin(self) -> builtins.str: ...
+        @property
+        def topic(self) -> builtins.str: ...
+        @property
+        def data(self) -> bytes: ...
+        def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
+
+    ...
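Both `PyFromSwarm` variants declare `__match_args__`, so the intended consumption style on the Python side is structural pattern matching over the single `recv()` stream that replaces the old `connection_update_recv`/`gossipsub_recv` pair. A sketch of a consumer loop; only `NetworkingHandle` and `PyFromSwarm` come from the stubs above, the rest is illustrative:

# Illustrative consumer loop for the new single-stream API.
async def pump(handle: NetworkingHandle) -> None:
    while True:
        match await handle.recv():
            case PyFromSwarm.Connection(peer_id, connected):
                state = "connected" if connected else "disconnected"
                print(f"peer {peer_id} {state}")
            case PyFromSwarm.Message(origin, topic, data):
                print(f"[{topic}] {origin}: {len(data)} bytes")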
@@ -2,7 +2,6 @@
 //!
 
 use pin_project::pin_project;
-use pyo3::marker::Ungil;
 use pyo3::prelude::*;
 use std::{
     future::Future,
@@ -26,8 +25,8 @@ where
 
 impl<F> Future for AllowThreads<F>
 where
-    F: Future + Ungil,
-    F::Output: Ungil,
+    F: Future + Send,
+    F::Output: Send,
 {
     type Output = F::Output;
rust/exo_pyo3_bindings/src/ident.rs (new file): 47 lines
@@ -0,0 +1,47 @@
+use crate::ext::ResultExt as _;
+use libp2p::identity::Keypair;
+use pyo3::types::{PyBytes, PyBytesMethods as _};
+use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
+use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
+
+/// Identity keypair of a node.
+#[gen_stub_pyclass]
+#[pyclass(name = "Keypair", frozen)]
+#[repr(transparent)]
+pub struct PyKeypair(pub Keypair);
+
+#[gen_stub_pymethods]
+#[pymethods]
+#[allow(clippy::needless_pass_by_value)]
+impl PyKeypair {
+    /// Generate a new Ed25519 keypair.
+    #[staticmethod]
+    fn generate() -> Self {
+        Self(Keypair::generate_ed25519())
+    }
+
+    /// Construct an Ed25519 keypair from secret key bytes
+    #[staticmethod]
+    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
+        let mut bytes = Vec::from(bytes.as_bytes());
+        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
+    }
+
+    /// Get the secret key bytes underlying the keypair
+    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+        let bytes = self
+            .0
+            .clone()
+            .try_into_ed25519()
+            .pyerr()?
+            .secret()
+            .as_ref()
+            .to_vec();
+        Ok(PyBytes::new(py, &bytes))
+    }
+
+    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
+    fn to_node_id(&self) -> String {
+        self.0.public().to_peer_id().to_base58()
+    }
+}
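A round-trip sketch of this keypair API from Python, matching the regenerated stubs; the module name `exo_pyo3_bindings` is an assumption:

# Hypothetical usage sketch; only the Keypair methods are from the stubs.
from exo_pyo3_bindings import Keypair

key = Keypair.generate()               # fresh Ed25519 keypair
secret = key.to_bytes()                # secret key bytes
restored = Keypair.from_bytes(secret)  # rebuild the same identity

# Both map to the same PeerId-derived node ID string.
assert key.to_node_id() == restored.to_node_id()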
@@ -4,26 +4,14 @@
 //!
 //!
 
-// enable Rust-unstable features for convenience
-#![feature(trait_alias)]
-#![feature(tuple_trait)]
-#![feature(unboxed_closures)]
-// #![feature(stmt_expr_attributes)]
-// #![feature(assert_matches)]
-// #![feature(async_fn_in_dyn_trait)]
-// #![feature(async_for_loop)]
-// #![feature(auto_traits)]
-// #![feature(negative_impls)]
-
-extern crate core;
 mod allow_threading;
-pub(crate) mod networking;
-pub(crate) mod pylibp2p;
+mod ident;
+mod networking;
 
+use crate::ident::PyKeypair;
 use crate::networking::networking_submodule;
-use crate::pylibp2p::ident::ident_submodule;
-use crate::pylibp2p::multiaddr::multiaddr_submodule;
 use pyo3::prelude::PyModule;
+use pyo3::types::PyModuleMethods;
 use pyo3::{Bound, PyResult, pyclass, pymodule};
 use pyo3_stub_gen::define_stub_info_gatherer;
 
@@ -32,14 +20,6 @@ pub(crate) mod r#const {
     pub const MPSC_CHANNEL_SIZE: usize = 1024;
 }
 
-/// Namespace for all the type/trait aliases used by this crate.
-pub(crate) mod alias {
-    use std::marker::Tuple;
-
-    pub trait SendFn<Args: Tuple + Send + 'static, Output> =
-        Fn<Args, Output = Output> + Send + 'static;
-}
-
 /// Namespace for crate-wide extension traits/methods
 pub(crate) mod ext {
     use crate::allow_threading::AllowThreads;
@@ -175,12 +155,14 @@ pub(crate) mod ext {
 fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // install logger
     pyo3_log::init();
+    let mut builder = tokio::runtime::Builder::new_multi_thread();
+    builder.enable_all();
+    pyo3_async_runtimes::tokio::init(builder);
 
     // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
     // work with maturin, where the types generate correctly, in the right folder, without
     // too many importing issues...
-    ident_submodule(m)?;
-    multiaddr_submodule(m)?;
+    m.add_class::<PyKeypair>()?;
     networking_submodule(m)?;
 
     // top-level constructs
@@ -1,26 +1,21 @@
-#![allow(
-    clippy::multiple_inherent_impl,
-    clippy::unnecessary_wraps,
-    clippy::unused_self,
-    clippy::needless_pass_by_value
-)]
+use std::sync::Arc;
 
 use crate::r#const::MPSC_CHANNEL_SIZE;
 use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
-use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
+use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
+use crate::ident::PyKeypair;
+use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
 use crate::pyclass;
-use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
-use libp2p::futures::StreamExt as _;
-use libp2p::gossipsub;
-use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
-use libp2p::swarm::SwarmEvent;
-use networking::discovery;
-use networking::swarm::create_swarm;
+use futures_lite::StreamExt as _;
+use libp2p::gossipsub::PublishError;
+use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
+use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::{PyModule, PyModuleMethods as _};
 use pyo3::types::PyBytes;
-use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
-use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
-use std::net::IpAddr;
+use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
+use pyo3_stub_gen::derive::{
+    gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
+};
 use tokio::sync::{Mutex, mpsc, oneshot};
 
 mod exception {
@@ -100,235 +95,45 @@ mod exception {
     }
 }
 
-/// Connection or disconnection event discriminant type.
-#[gen_stub_pyclass_enum]
-#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
-#[derive(Debug, Clone, PartialEq)]
-enum PyConnectionUpdateType {
-    Connected = 0,
-    Disconnected,
-}
-
-#[gen_stub_pyclass]
-#[pyclass(frozen, name = "ConnectionUpdate")]
-#[derive(Debug, Clone)]
-struct PyConnectionUpdate {
-    /// Whether this is a connection or disconnection event
-    #[pyo3(get)]
-    update_type: PyConnectionUpdateType,
-
-    /// Identity of the peer that we have connected to or disconnected from.
-    #[pyo3(get)]
-    peer_id: PyPeerId,
-
-    /// Remote connection's IPv4 address.
-    #[pyo3(get)]
-    remote_ipv4: String,
-
-    /// Remote connection's TCP port.
-    #[pyo3(get)]
-    remote_tcp_port: u16,
-}
-
-enum ToTask {
-    GossipsubSubscribe {
-        topic: String,
-        result_tx: oneshot::Sender<PyResult<bool>>,
-    },
-    GossipsubUnsubscribe {
-        topic: String,
-        result_tx: oneshot::Sender<bool>,
-    },
-    GossipsubPublish {
-        topic: String,
-        data: Vec<u8>,
-        result_tx: oneshot::Sender<PyResult<MessageId>>,
-    },
-}
-
-#[allow(clippy::enum_glob_use)]
-async fn networking_task(
-    mut swarm: networking::swarm::Swarm,
-    mut to_task_rx: mpsc::Receiver<ToTask>,
-    connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
-    gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
-) {
-    use SwarmEvent::*;
-    use ToTask::*;
-    use networking::swarm::BehaviourEvent::*;
-
-    log::info!("RUST: networking task started");
-
-    loop {
-        tokio::select! {
-            message = to_task_rx.recv() => {
-                // handle closed channel
-                let Some(message) = message else {
-                    log::info!("RUST: channel closed");
-                    break;
-                };
-
-                // dispatch incoming messages
-                match message {
-                    GossipsubSubscribe { topic, result_tx } => {
-                        // try to subscribe
-                        let result = swarm.behaviour_mut()
-                            .gossipsub.subscribe(&IdentTopic::new(topic));
-
-                        // send response oneshot
-                        if let Err(e) = result_tx.send(result.pyerr()) {
-                            log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
-                            continue;
-                        }
-                    }
-                    GossipsubUnsubscribe { topic, result_tx } => {
-                        // try to unsubscribe from the topic
-                        let result = swarm.behaviour_mut()
-                            .gossipsub.unsubscribe(&IdentTopic::new(topic));
-
-                        // send response oneshot (or exit if connection closed)
-                        if let Err(e) = result_tx.send(result) {
-                            log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
-                            continue;
-                        }
-                    }
-                    GossipsubPublish { topic, data, result_tx } => {
-                        // try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
-                        let result = swarm.behaviour_mut().gossipsub.publish(
-                            IdentTopic::new(topic), data);
-                        let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
-                            Err(exception::PyNoPeersSubscribedToTopicError::new_err())
-                        } else if let Err(PublishError::AllQueuesFull(_)) = result {
-                            Err(exception::PyAllQueuesFullError::new_err())
-                        } else {
-                            result.pyerr()
-                        };
-
-                        // send response oneshot (or exit if connection closed)
-                        if let Err(e) = result_tx.send(pyresult) {
-                            log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
-                            continue;
-                        }
-                    }
-                }
-            }
-
-            // architectural solution to this problem:
-            // create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
-            // -> it will emmit TRUE connected/disconnected events consumable elsewhere
-            //
-            // gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
-            // then for actual communication it will dial those peers if need-be
-            swarm_event = swarm.select_next_some() => {
-                match swarm_event {
-                    Behaviour(Gossipsub(gossipsub::Event::Message {
-                        message: Message {
-                            topic,
-                            data,
-                            ..
-                        },
-                        ..
-                    })) => {
-                        // topic-ID is just the topic hash!!! (since we used identity hasher)
-                        let message = (topic.into_string(), data);
-
-                        // send incoming message to channel (or exit if connection closed)
-                        if let Err(e) = gossipsub_message_tx.send(message).await {
-                            log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
-                            continue;
-                        }
-                    },
-                    Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
-                        // grab IPv4 string
-                        let remote_ipv4 = match remote_ip {
-                            IpAddr::V4(ip) => ip.to_string(),
-                            IpAddr::V6(ip) => {
-                                log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
-                                continue;
-                            }
-                        };
-
-                        // send connection event to channel (or exit if connection closed)
-                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
-                            update_type: PyConnectionUpdateType::Connected,
-                            peer_id: PyPeerId(peer_id),
-                            remote_ipv4,
-                            remote_tcp_port,
-                        }).await {
-                            log::error!("RUST: could not send connection update since channel already closed: {e}");
-                            continue;
-                        }
-                    },
-                    Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
-                        // grab IPv4 string
-                        let remote_ipv4 = match remote_ip {
-                            IpAddr::V4(ip) => ip.to_string(),
-                            IpAddr::V6(ip) => {
-                                log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
-                                continue;
-                            }
-                        };
-
-                        // send disconnection event to channel (or exit if connection closed)
-                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
-                            update_type: PyConnectionUpdateType::Disconnected,
-                            peer_id: PyPeerId(peer_id),
-                            remote_ipv4,
-                            remote_tcp_port,
-                        }).await {
-                            log::error!("RUST: could not send connection update since channel already closed: {e}");
-                            continue;
-                        }
-                    },
-                    e => {
-                        log::info!("RUST: other event {e:?}");
-                    }
-                }
-            }
-        }
-    }
-
-    log::info!("RUST: networking task stopped");
-}
-
 #[gen_stub_pyclass]
 #[pyclass(name = "NetworkingHandle")]
-#[derive(Debug)]
 struct PyNetworkingHandle {
     // channels
-    to_task_tx: Option<mpsc::Sender<ToTask>>,
-    connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
-    gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
+    pub to_swarm: mpsc::Sender<ToSwarm>,
+    pub swarm: Arc<Mutex<Swarm>>,
 }
 
-impl Drop for PyNetworkingHandle {
-    fn drop(&mut self) {
-        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
-        // to ensure that the networking task is done BEFORE exiting the clear function...
-        // but this may require GIL?? and it may not be safe to call GIL here??
-        self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
-    }
-}
-
-#[allow(clippy::expect_used)]
-impl PyNetworkingHandle {
-    fn new(
-        to_task_tx: mpsc::Sender<ToTask>,
-        connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
-        gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
-    ) -> Self {
-        Self {
-            to_task_tx: Some(to_task_tx),
-            connection_update_rx: Mutex::new(connection_update_rx),
-            gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
-        }
-    }
-
-    const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
-        self.to_task_tx
-            .as_ref()
-            .expect("The sender should only be None after de-initialization.")
-    }
+#[gen_stub_pyclass_complex_enum]
+#[pyclass]
+enum PyFromSwarm {
+    Connection {
+        peer_id: String,
+        connected: bool,
+    },
+    Message {
+        origin: String,
+        topic: String,
+        data: Py<PyBytes>,
+    },
+}
+
+impl From<FromSwarm> for PyFromSwarm {
+    fn from(value: FromSwarm) -> Self {
+        match value {
+            FromSwarm::Discovered { peer_id } => Self::Connection {
+                peer_id: peer_id.to_base58(),
+                connected: true,
+            },
+            FromSwarm::Expired { peer_id } => Self::Connection {
+                peer_id: peer_id.to_base58(),
+                connected: false,
+            },
+            FromSwarm::Message { from, topic, data } => Self::Message {
+                origin: from.to_base58(),
+                topic: topic,
+                data: data.pybytes(),
+            },
+        }
+    }
 }
 
 #[gen_stub_pymethods]
@@ -342,97 +147,36 @@ impl PyNetworkingHandle {
 
     #[new]
     fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
-        use pyo3_async_runtimes::tokio::get_runtime;
-
         // create communication channels
-        let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
-        let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
-        let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
+        let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
 
         // get identity
        let identity = identity.borrow().0.clone();
 
         // create networking swarm (within tokio context!! or it crashes)
-        let swarm = get_runtime()
-            .block_on(async { create_swarm(identity) })
-            .pyerr()?;
+        let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
+        let swarm = { create_swarm(identity, from_client).pyerr()? };
 
-        // spawn tokio task running the networking logic
-        get_runtime().spawn(async move {
-            networking_task(
-                swarm,
-                to_task_rx,
-                connection_update_tx,
-                gossipsub_message_tx,
-            )
-            .await;
-        });
-        Ok(Self::new(
-            to_task_tx,
-            connection_update_rx,
-            gossipsub_message_rx,
-        ))
+        Ok(Self {
+            swarm: Arc::new(Mutex::new(swarm)),
+            to_swarm,
+        })
     }
 
     #[gen_stub(skip)]
-    const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
-        Ok(()) // This is needed purely so `__clear__` can work
+    fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let swarm = Arc::clone(&self.swarm);
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            swarm
+                .try_lock()
+                .map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
+                .next()
+                .await
+                .ok_or(PyErr::receiver_channel_closed())
+                .map(PyFromSwarm::from)
+        })
    }
 
-    #[gen_stub(skip)]
-    fn __clear__(&mut self) {
-        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
-        // to ensure that the networking task is done BEFORE exiting the clear function...
-        // but this may require GIL?? and it may not be safe to call GIL here??
-        self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
-    }
-
-    // ---- Connection update receiver methods ----
-
-    /// Receives the next `ConnectionUpdate` from networking.
-    async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
-        self.connection_update_rx
-            .lock()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-            .recv_py()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-    }
-
-    /// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
-    ///
-    /// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
-    /// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
-    /// will sleep until a `ConnectionUpdate`s is sent.
-    async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
-        self.connection_update_rx
-            .lock()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-            .recv_many_py(limit)
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-    }
-
-    // TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
-    // so its too dangerous to expose just yet. figure out a better semantics for handling this,
-    // so things don't randomly block
-    // /// Tries to receive the next `ConnectionUpdate` from networking.
-    // fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
-    //     self.connection_update_rx.blocking_lock().try_recv_py()
-    // }
-    //
-    // /// Checks if the `ConnectionUpdate` channel is empty.
-    // fn connection_update_is_empty(&self) -> bool {
-    //     self.connection_update_rx.blocking_lock().is_empty()
-    // }
-    //
-    // /// Returns the number of `ConnectionUpdate`s in the channel.
-    // fn connection_update_len(&self) -> usize {
-    //     self.connection_update_rx.blocking_lock().len()
-    // }
-
     // ---- Gossipsub management methods ----
 
     /// Subscribe to a `GossipSub` topic.
@@ -442,10 +186,10 @@ impl PyNetworkingHandle {
         let (tx, rx) = oneshot::channel();
 
         // send off request to subscribe
-        self.to_task_tx()
-            .send_py(ToTask::GossipsubSubscribe {
+        self.to_swarm
+            .send_py(ToSwarm::Subscribe {
                 topic,
-                result_tx: tx,
+                result_sender: tx,
             })
             .allow_threads_py() // allow-threads-aware async call
             .await?;
@@ -454,6 +198,7 @@ impl PyNetworkingHandle {
         rx.allow_threads_py() // allow-threads-aware async call
             .await
             .map_err(|_| PyErr::receiver_channel_closed())?
+            .pyerr()
     }
 
     /// Unsubscribes from a `GossipSub` topic.
@@ -463,10 +208,10 @@ impl PyNetworkingHandle {
         let (tx, rx) = oneshot::channel();
 
         // send off request to unsubscribe
-        self.to_task_tx()
-            .send_py(ToTask::GossipsubUnsubscribe {
+        self.to_swarm
+            .send_py(ToSwarm::Unsubscribe {
                 topic,
-                result_tx: tx,
+                result_sender: tx,
             })
             .allow_threads_py() // allow-threads-aware async call
            .await?;
@@ -485,11 +230,11 @@ impl PyNetworkingHandle {
 
         // send off request to subscribe
         let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
-        self.to_task_tx()
-            .send_py(ToTask::GossipsubPublish {
+        self.to_swarm
+            .send_py(ToSwarm::Publish {
                 topic,
                 data,
-                result_tx: tx,
+                result_sender: tx,
             })
             .allow_threads_py() // allow-threads-aware async call
             .await?;
@@ -498,74 +243,33 @@ impl PyNetworkingHandle {
         let _ = rx
             .allow_threads_py() // allow-threads-aware async call
             .await
-            .map_err(|_| PyErr::receiver_channel_closed())??;
+            .map_err(|_| PyErr::receiver_channel_closed())?
+            .map_err(|e| match e {
+                PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
+                PublishError::NoPeersSubscribedToTopic => {
+                    PyNoPeersSubscribedToTopicError::new_err()
+                }
+                e => PyRuntimeError::new_err(e.to_string()),
+            })?;
         Ok(())
     }
+}
 
-    // ---- Gossipsub message receiver methods ----
-
-    /// Receives the next message from the `GossipSub` network.
-    async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
-        self.gossipsub_message_rx
-            .lock()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-            .recv_py()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-            .map(|(t, d)| (t, d.pybytes()))
-    }
-
-    /// Receives at most `limit` messages from the `GossipSub` network and returns them.
-    ///
-    /// For `limit = 0`, an empty collection of messages will be returned immediately.
-    /// For `limit > 0`, if there are no messages in the channel's queue this method
-    /// will sleep until a message is sent.
-    async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
-        Ok(self
-            .gossipsub_message_rx
-            .lock()
-            .allow_threads_py() // allow-threads-aware async call
-            .await
-            .recv_many_py(limit)
-            .allow_threads_py() // allow-threads-aware async call
-            .await?
-            .into_iter()
-            .map(|(t, d)| (t, d.pybytes()))
-            .collect())
-    }
-
-    // TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
-    // so its too dangerous to expose just yet. figure out a better semantics for handling this,
-    // so things don't randomly block
-    // /// Tries to receive the next message from the `GossipSub` network.
-    // fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
-    //     Ok(self
-    //         .gossipsub_message_rx
-    //         .blocking_lock()
-    //         .try_recv_py()?
-    //         .map(|(t, d)| (t, d.pybytes())))
-    // }
-    //
-    // /// Checks if the `GossipSub` message channel is empty.
-    // fn gossipsub_is_empty(&self) -> bool {
-    //     self.gossipsub_message_rx.blocking_lock().is_empty()
-    // }
-    //
-    // /// Returns the number of `GossipSub` messages in the channel.
-    // fn gossipsub_len(&self) -> usize {
-    //     self.gossipsub_message_rx.blocking_lock().len()
-    // }
-}
+pyo3_stub_gen::inventory::submit! {
+    gen_methods_from_python! {
+        r#"
+        class PyNetworkingHandle:
+            async def recv() -> PyFromSwarm: ...
+        "#
+    }
+}
 
 pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||||
|
|
||||||
m.add_class::<PyConnectionUpdateType>()?;
|
|
||||||
m.add_class::<PyConnectionUpdate>()?;
|
|
||||||
m.add_class::<PyConnectionUpdateType>()?;
|
|
||||||
m.add_class::<PyNetworkingHandle>()?;
|
m.add_class::<PyNetworkingHandle>()?;
|
||||||
|
m.add_class::<PyFromSwarm>()?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
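
Reviewer note: every gossipsub call above now funnels through the same mpsc-plus-oneshot request-reply shape — the caller enqueues a command carrying a `oneshot::Sender`, and the swarm task answers on it. A minimal, self-contained sketch of that shape (the `Command` enum and all names here are illustrative, not taken from the codebase):

use tokio::sync::{mpsc, oneshot};

enum Command {
    // Mirrors `ToSwarm::Subscribe`: the request carries its own reply channel.
    Subscribe {
        topic: String,
        result_sender: oneshot::Sender<bool>,
    },
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Command>(8);

    // Owning task: the only place the (pretend) resource is touched.
    tokio::spawn(async move {
        while let Some(Command::Subscribe { topic, result_sender }) = rx.recv().await {
            // pretend the subscription succeeded for non-empty topics
            let _ = result_sender.send(!topic.is_empty());
        }
    });

    // Client side: send the command, then await the private reply.
    let (result_tx, result_rx) = oneshot::channel();
    tx.send(Command::Subscribe {
        topic: "test-net".into(),
        result_sender: result_tx,
    })
    .await
    .expect("swarm task alive");
    assert!(result_rx.await.expect("reply sent"));
}

The oneshot gives each request a private reply channel, so results cannot be mis-routed between concurrent callers.
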

@@ -1,159 +0,0 @@
-use crate::ext::ResultExt as _;
-use libp2p::PeerId;
-use libp2p::identity::Keypair;
-use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
-use pyo3::types::PyBytes;
-use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
-use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
-
-/// Identity keypair of a node.
-#[gen_stub_pyclass]
-#[pyclass(name = "Keypair", frozen)]
-#[repr(transparent)]
-pub struct PyKeypair(pub Keypair);
-
-#[gen_stub_pymethods]
-#[pymethods]
-#[allow(clippy::needless_pass_by_value)]
-impl PyKeypair {
-    /// Generate a new Ed25519 keypair.
-    #[staticmethod]
-    fn generate_ed25519() -> Self {
-        Self(Keypair::generate_ed25519())
-    }
-
-    /// Generate a new ECDSA keypair.
-    #[staticmethod]
-    fn generate_ecdsa() -> Self {
-        Self(Keypair::generate_ecdsa())
-    }
-
-    /// Generate a new Secp256k1 keypair.
-    #[staticmethod]
-    fn generate_secp256k1() -> Self {
-        Self(Keypair::generate_secp256k1())
-    }
-
-    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
-    #[staticmethod]
-    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
-    }
-
-    /// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
-    /// format (i.e. unencrypted) as defined in [RFC5208].
-    ///
-    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
-    #[staticmethod]
-    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let mut bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
-    }
-
-    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
-    /// structure as defined in [RFC5915].
-    ///
-    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
-    #[staticmethod]
-    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let mut bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
-    }
-
-    #[staticmethod]
-    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let mut bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
-    }
-
-    /// Encode a private key as protobuf structure.
-    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
-        let bytes = self.0.to_protobuf_encoding().pyerr()?;
-        Ok(PyBytes::new(py, &bytes))
-    }
-
-    /// Convert the `Keypair` into the corresponding `PeerId`.
-    fn to_peer_id(&self) -> PyPeerId {
-        PyPeerId(self.0.public().to_peer_id())
-    }
-
-    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
-    // #[gen_stub(skip)]
-    // #[new]
-    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-    //     Self::from_protobuf_encoding(bytes)
-    // }
-    //
-    // #[gen_stub(skip)]
-    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
-    //     *self = Self::from_protobuf_encoding(state)?;
-    //     Ok(())
-    // }
-    //
-    // #[gen_stub(skip)]
-    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
-    //     self.to_protobuf_encoding(py)
-    // }
-    //
-    // #[gen_stub(skip)]
-    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
-    //     Ok((self.to_protobuf_encoding(py)?,))
-    // }
-}
-
-/// Identifier of a peer of the network.
-///
-/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
-/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
-#[gen_stub_pyclass]
-#[pyclass(name = "PeerId", frozen)]
-#[derive(Debug, Clone)]
-#[repr(transparent)]
-pub struct PyPeerId(pub PeerId);
-
-#[gen_stub_pymethods]
-#[pymethods]
-#[allow(clippy::needless_pass_by_value)]
-impl PyPeerId {
-    /// Generates a random peer ID from a cryptographically secure PRNG.
-    ///
-    /// This is useful for randomly walking on a DHT, or for testing purposes.
-    #[staticmethod]
-    fn random() -> Self {
-        Self(PeerId::random())
-    }
-
-    /// Parses a `PeerId` from bytes.
-    #[staticmethod]
-    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
-    }
-
-    /// Returns a raw bytes representation of this `PeerId`.
-    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
-        let bytes = self.0.to_bytes();
-        PyBytes::new(py, &bytes)
-    }
-
-    /// Returns a base-58 encoded string of this `PeerId`.
-    fn to_base58(&self) -> String {
-        self.0.to_base58()
-    }
-
-    fn __repr__(&self) -> String {
-        format!("PeerId({})", self.to_base58())
-    }
-
-    fn __str__(&self) -> String {
-        self.to_base58()
-    }
-}
-
-pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<PyKeypair>()?;
-    m.add_class::<PyPeerId>()?;
-
-    Ok(())
-}

@@ -1,8 +0,0 @@
-//! A module for exposing Rust's libp2p datatypes over Pyo3
-//!
-//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
-//! independent identity type of some kind or another. This may require handshaking.
-//!
-
-pub mod ident;
-pub mod multiaddr;

@@ -1,81 +0,0 @@
-use crate::ext::ResultExt as _;
-use libp2p::Multiaddr;
-use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
-use pyo3::types::PyBytes;
-use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
-use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
-use std::str::FromStr as _;
-
-/// Representation of a Multiaddr.
-#[gen_stub_pyclass]
-#[pyclass(name = "Multiaddr", frozen)]
-#[derive(Debug, Clone)]
-#[repr(transparent)]
-pub struct PyMultiaddr(pub Multiaddr);
-
-#[gen_stub_pymethods]
-#[pymethods]
-#[allow(clippy::needless_pass_by_value)]
-impl PyMultiaddr {
-    /// Create a new, empty multiaddress.
-    #[staticmethod]
-    fn empty() -> Self {
-        Self(Multiaddr::empty())
-    }
-
-    /// Create a new, empty multiaddress with the given capacity.
-    #[staticmethod]
-    fn with_capacity(n: usize) -> Self {
-        Self(Multiaddr::with_capacity(n))
-    }
-
-    /// Parse a `Multiaddr` value from its byte slice representation.
-    #[staticmethod]
-    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
-    }
-
-    /// Parse a `Multiaddr` value from its string representation.
-    #[staticmethod]
-    fn from_string(string: String) -> PyResult<Self> {
-        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
-    }
-
-    /// Return the length in bytes of this multiaddress.
-    fn len(&self) -> usize {
-        self.0.len()
-    }
-
-    /// Returns true if the length of this multiaddress is 0.
-    fn is_empty(&self) -> bool {
-        self.0.is_empty()
-    }
-
-    /// Return a copy of this [`Multiaddr`]'s byte representation.
-    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
-        let bytes = self.0.to_vec();
-        PyBytes::new(py, &bytes)
-    }
-
-    /// Convert a Multiaddr to a string.
-    fn to_string(&self) -> String {
-        self.0.to_string()
-    }
-
-    #[gen_stub(skip)]
-    fn __repr__(&self) -> String {
-        format!("Multiaddr({})", self.0)
-    }
-
-    #[gen_stub(skip)]
-    fn __str__(&self) -> String {
-        self.to_string()
-    }
-}
-
-pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<PyMultiaddr>()?;
-
-    Ok(())
-}

@@ -22,7 +22,7 @@ delegate = { workspace = true }

 # async
 tokio = { workspace = true, features = ["full"] }
-futures = { workspace = true }
+futures-lite = { workspace = true }
 futures-timer = { workspace = true }

 # utility dependencies
@@ -35,3 +35,4 @@ log = { workspace = true }

 # networking
 libp2p = { workspace = true, features = ["full"] }
+pin-project = "1.1.10"
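
The `futures` → `futures-lite` swap works because the handful of combinators this crate uses (`StreamExt::next`, `FutureExt::poll`, the `AsyncRead`/`AsyncWrite` traits) exist in both crates, and futures-lite pulls in far fewer dependencies. A tiny sketch of the equivalent driving loop under futures-lite alone (illustrative only, not part of the diff):

use futures_lite::{stream, StreamExt};

fn main() {
    // block_on is futures-lite's minimal executor; the real crate drives
    // its streams under tokio instead.
    futures_lite::future::block_on(async {
        let mut s = stream::iter([1, 2, 3]);
        // `next()` comes from futures_lite::StreamExt, the drop-in
        // replacement for the combinator previously taken from `futures`.
        while let Some(n) = s.next().await {
            println!("got {n}");
        }
    });
}
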

@@ -1,7 +1,9 @@
-use futures::stream::StreamExt as _;
-use libp2p::{gossipsub, identity, swarm::SwarmEvent};
-use networking::{discovery, swarm};
-use tokio::{io, io::AsyncBufReadExt as _, select};
+use futures_lite::StreamExt;
+use libp2p::identity;
+use networking::swarm;
+use networking::swarm::{FromSwarm, ToSwarm};
+use tokio::sync::{mpsc, oneshot};
+use tokio::{io, io::AsyncBufReadExt as _};
 use tracing_subscriber::EnvFilter;
 use tracing_subscriber::filter::LevelFilter;

@@ -11,64 +13,68 @@ async fn main() {
         .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
         .try_init();

+    let (to_swarm, from_client) = mpsc::channel(20);
+
     // Configure swarm
-    let mut swarm =
-        swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
+    let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
+        .expect("Swarm creation failed");

     // Create a Gossipsub topic & subscribe
-    let topic = gossipsub::IdentTopic::new("test-net");
-    swarm
-        .behaviour_mut()
-        .gossipsub
-        .subscribe(&topic)
-        .expect("Subscribing to topic failed");
+    let (tx, rx) = oneshot::channel();
+    _ = to_swarm
+        .send(ToSwarm::Subscribe {
+            topic: "test-net".to_string(),
+            result_sender: tx,
+        })
+        .await
+        .expect("should send");

     // Read full lines from stdin
     let mut stdin = io::BufReader::new(io::stdin()).lines();
     println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

+    tokio::task::spawn(async move {
+        rx.await
+            .expect("tx not dropped")
+            .expect("subscribe shouldn't fail");
+        loop {
+            if let Ok(Some(line)) = stdin.next_line().await {
+                let (tx, rx) = oneshot::channel();
+                if let Err(e) = to_swarm
+                    .send(swarm::ToSwarm::Publish {
+                        topic: "test-net".to_string(),
+                        data: line.as_bytes().to_vec(),
+                        result_sender: tx,
+                    })
+                    .await
+                {
+                    println!("Send error: {e:?}");
+                    return;
+                };
+                match rx.await {
+                    Ok(Err(e)) => println!("Publish error: {e:?}"),
+                    Err(e) => println!("Publish error: {e:?}"),
+                    Ok(_) => {}
+                }
+            }
+        }
+    });
+
     // Kick it off
     loop {
-        select! {
-            // on gossipsub outgoing
-            Ok(Some(line)) = stdin.next_line() => {
-                if let Err(e) = swarm
-                    .behaviour_mut().gossipsub
-                    .publish(topic.clone(), line.as_bytes()) {
-                    println!("Publish error: {e:?}");
-                }
-            }
-            event = swarm.select_next_some() => match event {
-                // on gossipsub incoming
-                SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
-                    propagation_source: peer_id,
-                    message_id: id,
-                    message,
-                })) => println!(
-                    "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
-                    String::from_utf8_lossy(&message.data),
-                ),
-
-                // on discovery
-                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
-                    discovery::Event::ConnectionEstablished {
-                        peer_id, connection_id, remote_ip, remote_tcp_port
-                    } => {
-                        println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
-                    }
-                    discovery::Event::ConnectionClosed {
-                        peer_id, connection_id, remote_ip, remote_tcp_port
-                    } => {
-                        eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
-                    }
-                }
-
-                // ignore outgoing errors: those are normal
-                e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
-
-                // otherwise log any other event
-                e => { log::info!("Other event {e:?}"); }
-            }
-        }
+        // on gossipsub outgoing
+        match swarm.next().await {
+            // on gossipsub incoming
+            Some(FromSwarm::Discovered { peer_id }) => {
+                println!("\n\nconnected to {peer_id}\n\n")
+            }
+            Some(FromSwarm::Expired { peer_id }) => {
+                println!("\n\ndisconnected from {peer_id}\n\n")
+            }
+            Some(FromSwarm::Message { from, topic, data }) => {
+                println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
+            }
+            None => {}
+        }
    }
 }
@@ -1,127 +0,0 @@
|
|||||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
|
||||||
//
|
|
||||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
|
||||||
// copy of this software and associated documentation files (the "Software"),
|
|
||||||
// to deal in the Software without restriction, including without limitation
|
|
||||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
||||||
// and/or sell copies of the Software, and to permit persons to whom the
|
|
||||||
// Software is furnished to do so, subject to the following conditions:
|
|
||||||
//
|
|
||||||
// The above copyright notice and this permission notice shall be included in
|
|
||||||
// all copies or substantial portions of the Software.
|
|
||||||
//
|
|
||||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
||||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
||||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
||||||
// DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
use futures::stream::StreamExt;
|
|
||||||
use libp2p::{
|
|
||||||
gossipsub, mdns, noise,
|
|
||||||
swarm::{NetworkBehaviour, SwarmEvent},
|
|
||||||
tcp, yamux,
|
|
||||||
};
|
|
||||||
use std::error::Error;
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::{io, io::AsyncBufReadExt, select};
|
|
||||||
use tracing_subscriber::EnvFilter;
|
|
||||||
|
|
||||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
|
||||||
#[derive(NetworkBehaviour)]
|
|
||||||
struct MyBehaviour {
|
|
||||||
gossipsub: gossipsub::Behaviour,
|
|
||||||
mdns: mdns::tokio::Behaviour,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), Box<dyn Error>> {
|
|
||||||
let _ = tracing_subscriber::fmt()
|
|
||||||
.with_env_filter(EnvFilter::from_default_env())
|
|
||||||
.try_init();
|
|
||||||
|
|
||||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
|
||||||
.with_tokio()
|
|
||||||
.with_tcp(
|
|
||||||
tcp::Config::default(),
|
|
||||||
noise::Config::new,
|
|
||||||
yamux::Config::default,
|
|
||||||
)?
|
|
||||||
.with_behaviour(|key| {
|
|
||||||
// Set a custom gossipsub configuration
|
|
||||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
|
||||||
.heartbeat_interval(Duration::from_secs(10))
|
|
||||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
|
||||||
.build()
|
|
||||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
|
||||||
|
|
||||||
// build a gossipsub network behaviour
|
|
||||||
let gossipsub = gossipsub::Behaviour::new(
|
|
||||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
|
||||||
gossipsub_config,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mdns =
|
|
||||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
|
||||||
Ok(MyBehaviour { gossipsub, mdns })
|
|
||||||
})?
|
|
||||||
.build();
|
|
||||||
|
|
||||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
|
||||||
|
|
||||||
// Create a Gossipsub topic
|
|
||||||
let topic = gossipsub::IdentTopic::new("test-net");
|
|
||||||
// subscribes to our topic
|
|
||||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
|
||||||
|
|
||||||
// Read full lines from stdin
|
|
||||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
|
||||||
|
|
||||||
// Listen on all interfaces and whatever port the OS assigns
|
|
||||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
|
||||||
|
|
||||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
|
||||||
|
|
||||||
// Kick it off
|
|
||||||
loop {
|
|
||||||
select! {
|
|
||||||
Ok(Some(line)) = stdin.next_line() => {
|
|
||||||
if let Err(e) = swarm
|
|
||||||
.behaviour_mut().gossipsub
|
|
||||||
.publish(topic.clone(), line.as_bytes()) {
|
|
||||||
println!("Publish error: {e:?}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
event = swarm.select_next_some() => match event {
|
|
||||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
|
||||||
for (peer_id, multiaddr) in list {
|
|
||||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
|
||||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
|
||||||
for (peer_id, multiaddr) in list {
|
|
||||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
|
||||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
|
||||||
propagation_source: peer_id,
|
|
||||||
message_id: id,
|
|
||||||
message,
|
|
||||||
})) => println!(
|
|
||||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
|
||||||
String::from_utf8_lossy(&message.data),
|
|
||||||
),
|
|
||||||
SwarmEvent::NewListenAddr { address, .. } => {
|
|
||||||
println!("Local node is listening on {address}");
|
|
||||||
}
|
|
||||||
e => {
|
|
||||||
println!("Other swarm event: {:?}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||

@@ -1,7 +1,7 @@
 use crate::ext::MultiaddrExt;
 use delegate::delegate;
 use either::Either;
-use futures::FutureExt;
+use futures_lite::FutureExt;
 use futures_timer::Delay;
 use libp2p::core::transport::PortUse;
 use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
         }

         // retry connecting to all mDNS peers periodically (fails safely if already connected)
-        if self.retry_delay.poll_unpin(cx).is_ready() {
+        if self.retry_delay.poll(cx).is_ready() {
             for (p, mas) in self.mdns_discovered.clone() {
                 for ma in mas {
                     self.dial(p, ma)
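
The one-line `poll_unpin` → `poll` change above relies on `futures_lite::FutureExt::poll`, which — like `futures::FutureExt::poll_unpin` — polls an `Unpin` future in place. A sketch of the same idiom in isolation (`poll_retry_timer` is a made-up helper, not from the crate):

use std::task::Context;
use std::time::Duration;

use futures_lite::FutureExt; // provides `.poll(cx)` for `Unpin` futures
use futures_timer::Delay;

/// Hypothetical helper: polls a retry timer and re-arms it once it fires.
fn poll_retry_timer(delay: &mut Delay, cx: &mut Context<'_>) -> bool {
    // `Delay` is `Unpin`, so futures-lite lets us poll it in place,
    // exactly like the old `poll_unpin` call did.
    if delay.poll(cx).is_ready() {
        *delay = Delay::new(Duration::from_secs(5));
        return true;
    }
    false
}
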

@@ -1,9 +1,12 @@
-use crate::alias;
-use crate::swarm::transport::tcp_transport;
-pub use behaviour::{Behaviour, BehaviourEvent};
-use libp2p::{SwarmBuilder, identity};
-
-pub type Swarm = libp2p::Swarm<Behaviour>;
+use std::pin::Pin;
+use std::task::Poll;
+
+use crate::swarm::transport::tcp_transport;
+use crate::{alias, discovery};
+pub use behaviour::{Behaviour, BehaviourEvent};
+use futures_lite::Stream;
+use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
+use tokio::sync::{mpsc, oneshot};

 /// The current version of the network: this prevents devices running different versions of the
 /// software from interacting with each other.
@@ -15,8 +18,144 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
 pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
 pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";

+pub enum ToSwarm {
+    Unsubscribe {
+        topic: String,
+        result_sender: oneshot::Sender<bool>,
+    },
+    Subscribe {
+        topic: String,
+        result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
+    },
+    Publish {
+        topic: String,
+        data: Vec<u8>,
+        result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
+    },
+}
+pub enum FromSwarm {
+    Message {
+        from: PeerId,
+        topic: String,
+        data: Vec<u8>,
+    },
+    Discovered {
+        peer_id: PeerId,
+    },
+    Expired {
+        peer_id: PeerId,
+    },
+}
+#[pin_project::pin_project]
+pub struct Swarm {
+    #[pin]
+    inner: libp2p::Swarm<Behaviour>,
+    from_client: mpsc::Receiver<ToSwarm>,
+}
+impl Swarm {
+    fn on_message(mut self: Pin<&mut Self>, message: ToSwarm) {
+        match message {
+            ToSwarm::Subscribe {
+                topic,
+                result_sender,
+            } => {
+                // try to subscribe
+                let result = self
+                    .inner
+                    .behaviour_mut()
+                    .gossipsub
+                    .subscribe(&gossipsub::IdentTopic::new(topic));
+
+                // send response oneshot
+                _ = result_sender.send(result)
+            }
+            ToSwarm::Unsubscribe {
+                topic,
+                result_sender,
+            } => {
+                // try to unsubscribe from the topic
+                let result = self
+                    .inner
+                    .behaviour_mut()
+                    .gossipsub
+                    .unsubscribe(&gossipsub::IdentTopic::new(topic));
+
+                // send response oneshot (or exit if connection closed)
+                _ = result_sender.send(result)
+            }
+            ToSwarm::Publish {
+                topic,
+                data,
+                result_sender,
+            } => {
+                // try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
+                let result = self
+                    .inner
+                    .behaviour_mut()
+                    .gossipsub
+                    .publish(gossipsub::IdentTopic::new(topic), data);
+                // send response oneshot (or exit if connection closed)
+                _ = result_sender.send(result)
+            }
+        }
+    }
+}
+impl Stream for Swarm {
+    type Item = FromSwarm;
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        loop {
+            let recv = self.as_mut().project().from_client;
+            match recv.poll_recv(cx) {
+                Poll::Ready(Some(msg)) => {
+                    self.as_mut().on_message(msg);
+                    // continue to re-poll after consumption
+                    continue;
+                }
+                Poll::Ready(None) => return Poll::Ready(None),
+                Poll::Pending => {}
+            }
+            let inner = self.as_mut().project().inner;
+            return match inner.poll_next(cx) {
+                Poll::Pending => Poll::Pending,
+                Poll::Ready(None) => Poll::Ready(None),
+                Poll::Ready(Some(swarm_event)) => match swarm_event {
+                    SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(
+                        gossipsub::Event::Message {
+                            message:
+                                gossipsub::Message {
+                                    source: Some(peer_id),
+                                    topic,
+                                    data,
+                                    ..
+                                },
+                            ..
+                        },
+                    )) => Poll::Ready(Some(FromSwarm::Message {
+                        from: peer_id,
+                        topic: topic.into_string(),
+                        data,
+                    })),
+                    SwarmEvent::Behaviour(BehaviourEvent::Discovery(
+                        discovery::Event::ConnectionEstablished { peer_id, .. },
+                    )) => Poll::Ready(Some(FromSwarm::Discovered { peer_id })),
+                    SwarmEvent::Behaviour(BehaviourEvent::Discovery(
+                        discovery::Event::ConnectionClosed { peer_id, .. },
+                    )) => Poll::Ready(Some(FromSwarm::Expired { peer_id })),
+                    // continue to re-poll after consumption
+                    _ => continue,
+                },
+            };
+        }
+    }
+}
 /// Create and configure a swarm which listens to all ports on OS
-pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
+pub fn create_swarm(
+    keypair: identity::Keypair,
+    from_client: mpsc::Receiver<ToSwarm>,
+) -> alias::AnyResult<Swarm> {
     let mut swarm = SwarmBuilder::with_existing_identity(keypair)
         .with_tokio()
         .with_other_transport(tcp_transport)?
@@ -25,13 +164,16 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {

     // Listen on all interfaces and whatever port the OS assigns
     swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
-    Ok(swarm)
+    Ok(Swarm {
+        inner: swarm,
+        from_client,
+    })
 }

 mod transport {
     use crate::alias;
     use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
-    use futures::{AsyncRead, AsyncWrite};
+    use futures_lite::{AsyncRead, AsyncWrite};
     use keccak_const::Sha3_256;
    use libp2p::core::muxing;
    use libp2p::core::transport::Boxed;
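
The new `Swarm` stream above encodes one design decision worth calling out: on every `poll_next`, pending client commands are fully drained before the inner swarm is polled, so command handling can never be starved by a busy event stream, and a closed command channel ends the stream. A reduced, illustrative sketch of that wrapper shape (`CommandStream`, `Cmd`, and `apply` are hypothetical names, not from the diff):

use std::pin::Pin;
use std::task::{Context, Poll};

use futures_lite::Stream;
use tokio::sync::mpsc;

#[pin_project::pin_project]
struct CommandStream<S, Cmd> {
    #[pin]
    inner: S,
    commands: mpsc::Receiver<Cmd>,
    apply: fn(&mut S, Cmd),
}

impl<S: Stream + Unpin, Cmd> Stream for CommandStream<S, Cmd> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        // 1) Drain every pending command so client requests never queue up
        //    behind a busy event stream.
        loop {
            match this.commands.poll_recv(cx) {
                Poll::Ready(Some(cmd)) => (this.apply)(this.inner.as_mut().get_mut(), cmd),
                // Channel closed: treat as end-of-stream, like the real Swarm.
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => break,
            }
        }
        // 2) Only then poll the wrapped stream for outbound events.
        this.inner.poll_next(cx)
    }
}
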

@@ -1,11 +1,10 @@
 { inputs, ... }:
 {
   perSystem =
-    { config, self', inputs', pkgs, lib, ... }:
+    { inputs', pkgs, lib, ... }:
     let
       # Fenix nightly toolchain with all components
-      fenixPkgs = inputs'.fenix.packages;
-      rustToolchain = fenixPkgs.complete.withComponents [
+      rustToolchain = inputs'.fenix.packages.stable.withComponents [
        "cargo"
        "rustc"
        "clippy"

@@ -1,2 +0,0 @@
-[toolchain]
-channel = "nightly"

@@ -45,7 +45,7 @@ class Node:
     @classmethod
     async def create(cls, args: "Args") -> "Self":
         keypair = get_node_id_keypair()
-        node_id = NodeId(keypair.to_peer_id().to_base58())
+        node_id = NodeId(keypair.to_node_id())
         session_id = SessionId(master_node_id=node_id, election_clock=0)
         router = Router.create(keypair)
         await router.register_topic(topics.GLOBAL_EVENTS)

@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
     ResponseOutputText,
     ResponsesRequest,
     ResponsesResponse,
+    ResponsesStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
     ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
 from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


+def _format_sse(event: ResponsesStreamEvent) -> str:
+    """Format a streaming event as an SSE message."""
+    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
+
+
 def _extract_content(content: str | list[ResponseContentPart]) -> str:
     """Extract plain text from a content field that may be a string or list of parts."""
     if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
     created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
+    yield _format_sse(created_event)

     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
+    yield _format_sse(in_progress_event)

     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
+    yield _format_sse(item_added)

     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
+    yield _format_sse(part_added)

     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     item=fc_item,
                 )
-                yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
+                yield _format_sse(fc_added)

                 # response.function_call_arguments.delta
                 args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     delta=tool.arguments,
                 )
-                yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
+                yield _format_sse(args_delta)

                 # response.function_call_arguments.done
                 args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
                     name=tool.name,
                     arguments=tool.arguments,
                 )
-                yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
+                yield _format_sse(args_done)

                 # response.output_item.done
                 fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     item=fc_done_item,
                 )
-                yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
+                yield _format_sse(fc_item_done)

                 function_call_items.append(fc_done_item)
                 next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
                 content_index=0,
                 delta=chunk.text,
             )
-            yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
+            yield _format_sse(delta_event)

     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
    )
-    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
+    yield _format_sse(text_done)

     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
+    yield _format_sse(part_done)

     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
+    yield _format_sse(item_done)

     # Create usage from usage data if available
     usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
+    yield _format_sse(completed_event)

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
 @pytest.mark.asyncio
 async def test_master():
     keypair = get_node_id_keypair()
-    node_id = NodeId(keypair.to_peer_id().to_base58())
+    node_id = NodeId(keypair.to_node_id())
     session_id = SessionId(master_node_id=node_id, election_clock=0)

     ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
     async with anyio.create_task_group() as tg:
         tg.start_soon(master.run)

-        sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
+        sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
         # inject a NodeGatheredInfo event
         logger.info("inject a NodeGatheredInfo event")
         await local_event_sender.send(

@@ -1,6 +1,4 @@
-from enum import Enum
+from exo_pyo3_bindings import PyFromSwarm

-from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType

 from exo.shared.types.common import NodeId
 from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
 """Serialisable types for Connection Updates/Messages"""


-class ConnectionMessageType(Enum):
-    Connected = 0
-    Disconnected = 1
-
-    @staticmethod
-    def from_update_type(update_type: ConnectionUpdateType):
-        match update_type:
-            case ConnectionUpdateType.Connected:
-                return ConnectionMessageType.Connected
-            case ConnectionUpdateType.Disconnected:
-                return ConnectionMessageType.Disconnected
-
-
 class ConnectionMessage(CamelCaseModel):
     node_id: NodeId
-    connection_type: ConnectionMessageType
-    remote_ipv4: str
-    remote_tcp_port: int
+    connected: bool

     @classmethod
-    def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
-        return cls(
-            node_id=NodeId(update.peer_id.to_base58()),
-            connection_type=ConnectionMessageType.from_update_type(update.update_type),
-            remote_ipv4=update.remote_ipv4,
-            remote_tcp_port=update.remote_tcp_port,
-        )
+    def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
+        return cls(node_id=NodeId(update.peer_id), connected=update.connected)

@@ -18,6 +18,7 @@ from exo_pyo3_bindings import (
     Keypair,
     NetworkingHandle,
     NoPeersSubscribedToTopicError,
+    PyFromSwarm,
 )
 from filelock import FileLock
 from loguru import logger
@@ -114,7 +115,6 @@ class Router:
         self._tg: TaskGroup | None = None

     async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
-        assert self._tg is None, "Attempted to register topic after setup time"
         send = self._tmp_networking_sender
         if send:
             self._tmp_networking_sender = None
@@ -122,7 +122,8 @@ class Router:
             send = self.networking_receiver.clone_sender()
         router = TopicRouter[T](topic, send)
         self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
-        await self._networking_subscribe(str(topic.topic))
+        if self._tg is not None:
+            await self._networking_subscribe(topic.topic)

     def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
         router = self.topic_routers.get(topic.topic, None)
@@ -154,8 +155,10 @@ class Router:
                 router = self.topic_routers[topic]
                 tg.start_soon(router.run)
             tg.start_soon(self._networking_recv)
-            tg.start_soon(self._networking_recv_connection_messages)
             tg.start_soon(self._networking_publish)
+            # subscribe to pending topics
+            for topic in self.topic_routers:
+                await self._networking_subscribe(topic)
             # Router only shuts down if you cancel it.
             await sleep_forever()
         finally:
@@ -179,27 +182,33 @@ class Router:

     async def _networking_recv(self):
         while True:
-            topic, data = await self._net.gossipsub_recv()
-            logger.trace(f"Received message on {topic} with payload {data}")
-            if topic not in self.topic_routers:
-                logger.warning(f"Received message on unknown or inactive topic {topic}")
-                continue
-
-            router = self.topic_routers[topic]
-            await router.publish_bytes(data)
-
-    async def _networking_recv_connection_messages(self):
-        while True:
-            update = await self._net.connection_update_recv()
-            message = ConnectionMessage.from_update(update)
-            logger.trace(
-                f"Received message on connection_messages with payload {message}"
-            )
-            if CONNECTION_MESSAGES.topic in self.topic_routers:
-                router = self.topic_routers[CONNECTION_MESSAGES.topic]
-                assert router.topic.model_type == ConnectionMessage
-                router = cast(TopicRouter[ConnectionMessage], router)
-                await router.publish(message)
+            from_swarm = await self._net.recv()
+            logger.debug(from_swarm)
+            match from_swarm:
+                case PyFromSwarm.Message(origin, topic, data):
+                    logger.trace(
+                        f"Received message on {topic} from {origin} with payload {data}"
+                    )
+                    if topic not in self.topic_routers:
+                        logger.warning(
+                            f"Received message on unknown or inactive topic {topic}"
+                        )
+                        continue
+
+                    router = self.topic_routers[topic]
+                    await router.publish_bytes(data)
+                case PyFromSwarm.Connection():
+                    message = ConnectionMessage.from_update(from_swarm)
+                    logger.trace(
+                        f"Received message on connection_messages with payload {message}"
+                    )
+                    if CONNECTION_MESSAGES.topic in self.topic_routers:
+                        router = self.topic_routers[CONNECTION_MESSAGES.topic]
+                        assert router.topic.model_type == ConnectionMessage
+                        router = cast(TopicRouter[ConnectionMessage], router)
+                        await router.publish(message)
+                case _:
+                    raise AssertionError("exhaustive net messages have been checked")

     async def _networking_publish(self):
         with self.networking_receiver as networked_items:
@@ -221,7 +230,7 @@ def get_node_id_keypair(
     Obtain the :class:`PeerId` by from it.
     """
     # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
-    return Keypair.generate_ed25519()
+    return Keypair.generate()

     def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
         return Path(str(path) + ".lock")
@@ -235,12 +244,12 @@ def get_node_id_keypair(
             protobuf_encoded = f.read()

             try:  # if decoded successfully, save & return
-                return Keypair.from_protobuf_encoding(protobuf_encoded)
+                return Keypair.from_bytes(protobuf_encoded)
             except ValueError as e:  # on runtime error, assume corrupt file
                 logger.warning(f"Encountered error when trying to get keypair: {e}")

         # if no valid credentials, create new ones and persist
         with open(path, "w+b") as f:
             keypair = Keypair.generate_ed25519()
-            f.write(keypair.to_protobuf_encoding())
+            f.write(keypair.to_bytes())
             return keypair

@@ -1,7 +1,7 @@
 import pytest
 from anyio import create_task_group, fail_after, move_on_after

-from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
+from exo.routing.connection_message import ConnectionMessage
 from exo.shared.election import Election, ElectionMessage, ElectionResult
 from exo.shared.types.commands import ForwarderCommand, TestCommand
 from exo.shared.types.common import NodeId, SessionId
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
         tg.start_soon(election.run)

         # Send any connection message object; we close quickly to cancel before result creation
-        await cm_tx.send(
-            ConnectionMessage(
-                node_id=NodeId(),
-                connection_type=ConnectionMessageType.Connected,
-                remote_ipv4="",
-                remote_tcp_port=0,
-            )
-        )
+        await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))

         # Expect a broadcast for the new round at clock=1
         while True:

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
     sem.release()
     # wait to be told to begin simultaneous read
     ev.wait()
-    queue.put(get_node_id_keypair().to_protobuf_encoding())
+    queue.put(get_node_id_keypair().to_bytes())


 def _get_keypair_concurrent(num_procs: int) -> bytes:

@@ -241,6 +241,11 @@ class Worker:
                     cancelled_task_id=cancelled_task_id, runner_id=runner_id
                 ):
                     await self.runners[runner_id].cancel_task(cancelled_task_id)
+                    await self.event_sender.send(
+                        TaskStatusUpdated(
+                            task_id=task.task_id, task_status=TaskStatus.Complete
+                        )
+                    )
                 case ImageEdits() if task.task_params.total_input_chunks > 0:
                     # Assemble image from chunks and inject into task
                     cmd_id = task.command_id